Spaces · Runtime error

RockeyCoss committed

Commit 51f6859 · 1 Parent(s): 54090b5

add code files

Browse files (this view is limited to 50 files because it contains too many changes; see raw diff).
- .gitignore +125 -0
- LICENSE +203 -0
- app.py +133 -0
- assets/img1.jpg +0 -0
- assets/img2.jpg +0 -0
- assets/img3.jpg +0 -0
- assets/img4.jpg +0 -0
- flagged/Input/tmpaytsmk0e.jpg +0 -0
- flagged/Output/tmpgs59m7u_.png +0 -0
- flagged/log.csv +2 -0
- mmdet/__init__.py +29 -0
- mmdet/apis/__init__.py +12 -0
- mmdet/apis/inference.py +258 -0
- mmdet/apis/test.py +209 -0
- mmdet/apis/train.py +246 -0
- mmdet/core/__init__.py +10 -0
- mmdet/core/anchor/__init__.py +14 -0
- mmdet/core/anchor/anchor_generator.py +866 -0
- mmdet/core/anchor/builder.py +19 -0
- mmdet/core/anchor/point_generator.py +263 -0
- mmdet/core/anchor/utils.py +72 -0
- mmdet/core/bbox/__init__.py +28 -0
- mmdet/core/bbox/assigners/__init__.py +25 -0
- mmdet/core/bbox/assigners/approx_max_iou_assigner.py +146 -0
- mmdet/core/bbox/assigners/ascend_assign_result.py +34 -0
- mmdet/core/bbox/assigners/ascend_max_iou_assigner.py +178 -0
- mmdet/core/bbox/assigners/assign_result.py +206 -0
- mmdet/core/bbox/assigners/atss_assigner.py +234 -0
- mmdet/core/bbox/assigners/base_assigner.py +10 -0
- mmdet/core/bbox/assigners/center_region_assigner.py +336 -0
- mmdet/core/bbox/assigners/grid_assigner.py +156 -0
- mmdet/core/bbox/assigners/hungarian_assigner.py +139 -0
- mmdet/core/bbox/assigners/mask_hungarian_assigner.py +125 -0
- mmdet/core/bbox/assigners/max_iou_assigner.py +218 -0
- mmdet/core/bbox/assigners/point_assigner.py +134 -0
- mmdet/core/bbox/assigners/region_assigner.py +222 -0
- mmdet/core/bbox/assigners/sim_ota_assigner.py +257 -0
- mmdet/core/bbox/assigners/task_aligned_assigner.py +151 -0
- mmdet/core/bbox/assigners/uniform_assigner.py +135 -0
- mmdet/core/bbox/builder.py +21 -0
- mmdet/core/bbox/coder/__init__.py +15 -0
- mmdet/core/bbox/coder/base_bbox_coder.py +18 -0
- mmdet/core/bbox/coder/bucketing_bbox_coder.py +351 -0
- mmdet/core/bbox/coder/delta_xywh_bbox_coder.py +392 -0
- mmdet/core/bbox/coder/distance_point_bbox_coder.py +63 -0
- mmdet/core/bbox/coder/legacy_delta_xywh_bbox_coder.py +216 -0
- mmdet/core/bbox/coder/pseudo_bbox_coder.py +19 -0
- mmdet/core/bbox/coder/tblr_bbox_coder.py +206 -0
- mmdet/core/bbox/coder/yolo_bbox_coder.py +83 -0
- mmdet/core/bbox/demodata.py +42 -0
 
    	
.gitignore ADDED
@@ -0,0 +1,125 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/en/_build/
+docs/zh_cn/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+data/
+data
+.vscode
+.idea
+.DS_Store
+
+# custom
+*.pkl
+*.pkl.json
+*.log.json
+docs/modelzoo_statistics.md
+mmdet/.mim
+work_dirs/
+ckpt/
+
+# Pytorch
+*.pth
+*.py~
+*.sh~
    	
LICENSE ADDED
@@ -0,0 +1,203 @@
+Copyright 2018-2023 OpenMMLab. All rights reserved.
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2018-2023 OpenMMLab.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
    	
app.py ADDED
@@ -0,0 +1,133 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from collections import OrderedDict
+
+import torch
+from mmcv import Config
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+
+from mmdet.apis import init_detector, inference_detector
+from mmdet.datasets import (CocoDataset)
+from mmdet.utils import (compat_cfg, replace_cfg_vals, setup_multi_processes,
+                         update_data_root)
+
+import gradio as gr
+
+config_dict = OrderedDict([('swin-l-hdetr_sam-vit-b', 'projects/configs/hdetr/swin-l-hdetr_sam-vit-b.py'),
+                           ('swin-l-hdetr_sam-vit-l', 'projects/configs/hdetr/swin-l-hdetr_sam-vit-l.py'),
+                           ('swin-l-hdetr_sam-vit-h', 'projects/configs/hdetr/swin-l-hdetr_sam-vit-l.py'),
+                           ('focalnet-l-dino_sam-vit-b', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-b.py'),
+                           ('focalnet-l-dino_sam-vit-l', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-l.py'),
+                           (
+                           'focalnet-l-dino_sam-vit-h', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-h.py')])
+
+
+def inference(img, config):
+    if img is None:
+        return None
+    config = config_dict[config]
+    cfg = Config.fromfile(config)
+
+    # replace the ${key} with the value of cfg.key
+    cfg = replace_cfg_vals(cfg)
+
+    # update data root according to MMDET_DATASETS
+    update_data_root(cfg)
+
+    cfg = compat_cfg(cfg)
+
+    # set multi-process settings
+    setup_multi_processes(cfg)
+
+    # import modules from plugin/xx, registry will be updated
+    if hasattr(cfg, 'plugin'):
+        if cfg.plugin:
+            import importlib
+            if hasattr(cfg, 'plugin_dir'):
+                plugin_dir = cfg.plugin_dir
+                _module_dir = os.path.dirname(plugin_dir)
+                _module_dir = _module_dir.split('/')
+                _module_path = _module_dir[0]
+
+                for m in _module_dir[1:]:
+                    _module_path = _module_path + '.' + m
+                print(_module_path)
+                plg_lib = importlib.import_module(_module_path)
+            else:
+                # import dir is the dirpath for the config file
+                _module_dir = os.path.dirname(config)
+                _module_dir = _module_dir.split('/')
+                _module_path = _module_dir[0]
+                for m in _module_dir[1:]:
+                    _module_path = _module_path + '.' + m
+                # print(_module_path)
+                plg_lib = importlib.import_module(_module_path)
+
+    # set cudnn_benchmark
+    if cfg.get('cudnn_benchmark', False):
+        torch.backends.cudnn.benchmark = True
+    if IS_CUDA_AVAILABLE or IS_MLU_AVAILABLE:
+        device = "cuda"
+    else:
+        device = "cpu"
+    model = init_detector(cfg, None, device=device)
+    model.CLASSES = CocoDataset.CLASSES
+
+    results = inference_detector(model, img)
+    visualize = model.show_result(
+        img,
+        results,
+        bbox_color=CocoDataset.PALETTE,
+        text_color=CocoDataset.PALETTE,
+        mask_color=CocoDataset.PALETTE,
+        show=False,
+        out_file=None,
+        score_thr=0.3
+    )
+    del model
+    return visualize
+
+
+description = """
+#  <center>Prompt Segment Anything (zero-shot instance segmentation demo)</center>
+Github link: [Link](https://github.com/RockeyCoss/Prompt-Segment-Anything)
+You can select the model you want to use from the "Model" dropdown menu and click "Submit" to segment the image you uploaded to the "Input Image" box.
+"""
+
+
+def main():
+    with gr.Blocks() as demo:
+        gr.Markdown(description)
+        with gr.Column():
+            with gr.Row():
+                with gr.Column():
+                    input_img = gr.Image(type="numpy", label="Input Image")
+                    model_type = gr.Dropdown(choices=list(config_dict.keys()),
+                                             value=list(config_dict.keys())[0],
+                                             label='Model',
+                                             multiselect=False)
+                    with gr.Row():
+                        clear_btn = gr.Button(value="Clear")
+                        submit_btn = gr.Button(value="Submit")
+                output_img = gr.Image(type="numpy", label="Output")
+            gr.Examples(
+                examples=[["./assets/img1.jpg", "swin-l-hdetr_sam-vit-b"],
+                          ["./assets/img2.jpg", "swin-l-hdetr_sam-vit-l"],
+                          ["./assets/img3.jpg", "swin-l-hdetr_sam-vit-l"],
+                          ["./assets/img4.jpg", "focalnet-l-dino_sam-vit-b"]],
+                inputs=[input_img, model_type],
+                outputs=output_img,
+                fn=inference
+            )
+
+        submit_btn.click(inference,
+                         inputs=[input_img, model_type],
+                         outputs=output_img)
+        clear_btn.click(lambda: [None, None], None, [input_img, output_img], queue=False)
+
+    demo.queue()
+    demo.launch(share=True)
+
+
+if __name__ == '__main__':
+    main()
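For reference, the demo's `inference()` function can also be driven without the Gradio UI. Below is a minimal sketch that is not part of the commit; it assumes the repository's `projects/configs` layout is on disk, that the listed example image exists, and that mmcv/mmdet are installed as the app expects:

    # Hypothetical driver for the inference() function above (not in the commit).
    import mmcv

    img = mmcv.imread('assets/img1.jpg')            # load as an HWC BGR numpy array
    vis = inference(img, 'swin-l-hdetr_sam-vit-b')  # key must exist in config_dict
    mmcv.imwrite(vis, 'demo_output.jpg')            # save the rendered result

Note that `inference()` rebuilds the detector on every call (`init_detector` ... `del model`), so repeated calls reload weights each time; that matches the app's per-request behavior.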
    	
assets/img1.jpg ADDED (binary file)

assets/img2.jpg ADDED (binary file)

assets/img3.jpg ADDED (binary file)

assets/img4.jpg ADDED (binary file)

flagged/Input/tmpaytsmk0e.jpg ADDED (binary file)

flagged/Output/tmpgs59m7u_.png ADDED (binary file)
    	
    flagged/log.csv
ADDED
@@ -0,0 +1,2 @@
Input,Output,flag,username,timestamp
C:\Users\13502\Documents\msra\prompt_segment_anything_demo\flagged\Input\tmpaytsmk0e.jpg,C:\Users\13502\Documents\msra\prompt_segment_anything_demo\flagged\Output\tmpgs59m7u_.png,,,2023-04-10 20:52:40.908980
    	
    mmdet/__init__.py
ADDED
@@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv

from .version import __version__, short_version


def digit_version(version_str):
    digit_version = []
    for x in version_str.split('.'):
        if x.isdigit():
            digit_version.append(int(x))
        elif x.find('rc') != -1:
            patch_version = x.split('rc')
            digit_version.append(int(patch_version[0]) - 1)
            digit_version.append(int(patch_version[1]))
    return digit_version


mmcv_minimum_version = '1.3.17'
mmcv_maximum_version = '1.8.0'
mmcv_version = digit_version(mmcv.__version__)


assert (mmcv_version >= digit_version(mmcv_minimum_version)
        and mmcv_version <= digit_version(mmcv_maximum_version)), \
    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
    f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'

__all__ = ['__version__', 'short_version']
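Editor's note on digit_version: an 'rc' component is split so that a release candidate sorts below the corresponding final release (the part before 'rc' is decremented and the rc number appended). Illustrative examples, not part of the commit:

# digit_version('1.7.0')    -> [1, 7, 0]
# digit_version('1.7.0rc1') -> [1, 7, -1, 1], so it compares lower:
assert digit_version('1.7.0rc1') < digit_version('1.7.0')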
    	
    mmdet/apis/__init__.py
ADDED
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import (async_inference_detector, inference_detector,
                        init_detector, show_result_pyplot)
from .test import multi_gpu_test, single_gpu_test
from .train import (get_root_logger, init_random_seed, set_random_seed,
                    train_detector)

__all__ = [
    'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector',
    'async_inference_detector', 'inference_detector', 'show_result_pyplot',
    'multi_gpu_test', 'single_gpu_test', 'init_random_seed'
]
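Editor's note: since these names are re-exported at the package level, downstream code imports from mmdet.apis directly. A hedged sketch of the seeding helpers (signatures as in upstream mmdet 2.x; the values are placeholders):

from mmdet.apis import init_random_seed, set_random_seed

seed = init_random_seed(None)             # pick a seed and sync it across ranks
set_random_seed(seed, deterministic=False)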
    	
    mmdet/apis/inference.py
ADDED
@@ -0,0 +1,258 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from pathlib import Path

import mmcv
import numpy as np
import torch
from mmcv.ops import RoIPool
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint

from mmdet.core import get_classes
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector


def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
    """Initialize a detector from config file.

    Args:
        config (str, :obj:`Path`, or :obj:`mmcv.Config`): Config file path,
            :obj:`Path`, or the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the
            model will not load any weights.
        device (str, optional): The device to put the model on.
            Default: 'cuda:0'.
        cfg_options (dict): Options to override some settings in the used
            config.

    Returns:
        nn.Module: The constructed detector.
    """
    if isinstance(config, (str, Path)):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    if 'pretrained' in config.model:
        config.model.pretrained = None
    elif (config.model.get('backbone', None) is not None
          and 'init_cfg' in config.model.backbone):
        config.model.backbone.init_cfg = None
    config.model.train_cfg = None
    model = build_detector(config.model, test_cfg=config.get('test_cfg'))
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        if 'CLASSES' in checkpoint.get('meta', {}):
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()

    if device == 'npu':
        from mmcv.device.npu import NPUDataParallel
        model = NPUDataParallel(model)
        model.cfg = config

    return model


class LoadImage:
    """Deprecated.

    A simple pipeline to load image.
    """

    def __call__(self, results):
        """Call function to load images into results.

        Args:
            results (dict): A result dict contains the file name
                of the image to be read.
        Returns:
            dict: ``results`` will be returned containing loaded image.
        """
        warnings.simplefilter('once')
        warnings.warn('`LoadImage` is deprecated and will be removed in '
                      'future releases. You may use `LoadImageFromWebcam` '
                      'from `mmdet.datasets.pipelines.` instead.')
        if isinstance(results['img'], str):
            results['filename'] = results['img']
            results['ori_filename'] = results['img']
        else:
            results['filename'] = None
            results['ori_filename'] = None
        img = mmcv.imread(results['img'])
        results['img'] = img
        results['img_fields'] = ['img']
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        return results


def inference_detector(model, imgs):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images.

    Returns:
        If imgs is a list or tuple, a list of results of the same length
        will be returned, otherwise the detection result is returned
        directly.
    """
    ori_img = imgs
    if isinstance(imgs, (list, tuple)):
        is_batch = True
    else:
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    if isinstance(imgs[0], np.ndarray):
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    datas = []
    for img in imgs:
        # prepare data
        if isinstance(img, np.ndarray):
            # directly add img
            data = dict(img=img)
        else:
            # add information into dict
            data = dict(img_info=dict(filename=img), img_prefix=None)
        # build the data pipeline
        data = test_pipeline(data)
        datas.append(data)

    data = collate(datas, samples_per_gpu=len(imgs))
    # just get the actual data from DataContainer
    data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
    data['img'] = [img.data[0] for img in data['img']]
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    # forward the model
    with torch.no_grad():
        results = model(return_loss=False, rescale=True, **data, ori_img=ori_img)

    if not is_batch:
        return results[0]
    else:
        return results


async def async_inference_detector(model, imgs):
    """Async inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str | ndarray): Either image files or loaded images.

    Returns:
        Awaitable detection results.
    """
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]

    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    if isinstance(imgs[0], np.ndarray):
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    datas = []
    for img in imgs:
        # prepare data
        if isinstance(img, np.ndarray):
            # directly add img
            data = dict(img=img)
        else:
            # add information into dict
            data = dict(img_info=dict(filename=img), img_prefix=None)
        # build the data pipeline
        data = test_pipeline(data)
        datas.append(data)

    data = collate(datas, samples_per_gpu=len(imgs))
    # just get the actual data from DataContainer
    data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
    data['img'] = [img.data[0] for img in data['img']]
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    # We don't restore `torch.is_grad_enabled()` value during concurrent
    # inference since execution can overlap
    torch.set_grad_enabled(False)
    results = await model.aforward_test(rescale=True, **data)
    return results


def show_result_pyplot(model,
                       img,
                       result,
                       score_thr=0.3,
                       title='result',
                       wait_time=0,
                       palette=None,
                       out_file=None):
    """Visualize the detection results on the image.

    Args:
        model (nn.Module): The loaded detector.
        img (str or np.ndarray): Image filename or loaded image.
        result (tuple[list] or list): The detection result, can be either
            (bbox, segm) or just bbox.
        score_thr (float): The threshold to visualize the bboxes and masks.
        title (str): Title of the pyplot figure.
        wait_time (float): Value of waitKey param. Default: 0.
        palette (str or tuple(int) or :obj:`Color`): Color.
            The tuple of color should be in BGR order.
        out_file (str or None): The path to write the image.
            Default: None.
    """
    if hasattr(model, 'module'):
        model = model.module
    model.show_result(
        img,
        result,
        score_thr=score_thr,
        show=True,
        wait_time=wait_time,
        win_name=title,
        bbox_color=palette,
        text_color=(200, 200, 200),
        mask_color=palette,
        out_file=out_file)
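Editor's note: a short usage sketch for the three public entry points above (illustrative only; the config and checkpoint paths are placeholders, and note that this repo's inference_detector additionally forwards the raw input to the model as ori_img):

from mmdet.apis import inference_detector, init_detector, show_result_pyplot

model = init_detector('path/to/config.py', 'path/to/checkpoint.pth',
                      device='cuda:0')
result = inference_detector(model, 'demo.jpg')   # single input -> single result
show_result_pyplot(model, 'demo.jpg', result, score_thr=0.3)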
    	
    mmdet/apis/test.py
ADDED
@@ -0,0 +1,209 @@
|
| 1 | 
         
            +
            # Copyright (c) OpenMMLab. All rights reserved.
         
     | 
| 2 | 
         
            +
            import os.path as osp
         
     | 
| 3 | 
         
            +
            import pickle
         
     | 
| 4 | 
         
            +
            import shutil
         
     | 
| 5 | 
         
            +
            import tempfile
         
     | 
| 6 | 
         
            +
            import time
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
            import mmcv
         
     | 
| 9 | 
         
            +
            import torch
         
     | 
| 10 | 
         
            +
            import torch.distributed as dist
         
     | 
| 11 | 
         
            +
            from mmcv.image import tensor2imgs
         
     | 
| 12 | 
         
            +
            from mmcv.runner import get_dist_info
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
            from mmdet.core import encode_mask_results
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
            def single_gpu_test(model,
         
     | 
| 18 | 
         
            +
                                data_loader,
         
     | 
| 19 | 
         
            +
                                show=False,
         
     | 
| 20 | 
         
            +
                                out_dir=None,
         
     | 
| 21 | 
         
            +
                                show_score_thr=0.3):
         
     | 
| 22 | 
         
            +
                model.eval()
         
     | 
| 23 | 
         
            +
                results = []
         
     | 
| 24 | 
         
            +
                dataset = data_loader.dataset
         
     | 
| 25 | 
         
            +
                PALETTE = getattr(dataset, 'PALETTE', None)
         
     | 
| 26 | 
         
            +
                prog_bar = mmcv.ProgressBar(len(dataset))
         
     | 
| 27 | 
         
            +
                for i, data in enumerate(data_loader):
         
     | 
| 28 | 
         
            +
                    with torch.no_grad():
         
     | 
| 29 | 
         
            +
                        result = model(return_loss=False, rescale=True, **data)
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
                    batch_size = len(result)
         
     | 
| 32 | 
         
            +
                    if show or out_dir:
         
     | 
| 33 | 
         
            +
                        if batch_size == 1 and isinstance(data['img'][0], torch.Tensor):
         
     | 
| 34 | 
         
            +
                            img_tensor = data['img'][0]
         
     | 
| 35 | 
         
            +
                        else:
         
     | 
| 36 | 
         
            +
                            img_tensor = data['img'][0].data[0]
         
     | 
| 37 | 
         
            +
                        img_metas = data['img_metas'][0].data[0]
         
     | 
| 38 | 
         
            +
                        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
         
     | 
| 39 | 
         
            +
                        assert len(imgs) == len(img_metas)
         
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
                        for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
         
     | 
| 42 | 
         
            +
                            h, w, _ = img_meta['img_shape']
         
     | 
| 43 | 
         
            +
                            img_show = img[:h, :w, :]
         
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
                            ori_h, ori_w = img_meta['ori_shape'][:-1]
         
     | 
| 46 | 
         
            +
                            img_show = mmcv.imresize(img_show, (ori_w, ori_h))
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
                            if out_dir:
         
     | 
| 49 | 
         
            +
                                out_file = osp.join(out_dir, img_meta['ori_filename'])
         
     | 
| 50 | 
         
            +
                            else:
         
     | 
| 51 | 
         
            +
                                out_file = None
         
     | 
| 52 | 
         
            +
             
     | 
| 53 | 
         
            +
                            model.module.show_result(
         
     | 
| 54 | 
         
            +
                                img_show,
         
     | 
| 55 | 
         
            +
                                result[i],
         
     | 
| 56 | 
         
            +
                                bbox_color=PALETTE,
         
     | 
| 57 | 
         
            +
                                text_color=PALETTE,
         
     | 
| 58 | 
         
            +
                                mask_color=PALETTE,
         
     | 
| 59 | 
         
            +
                                show=show,
         
     | 
| 60 | 
         
            +
                                out_file=out_file,
         
     | 
| 61 | 
         
            +
                                score_thr=show_score_thr)
         
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
                    # encode mask results
         
     | 
| 64 | 
         
            +
                    if isinstance(result[0], tuple):
         
     | 
| 65 | 
         
            +
                        result = [(bbox_results, encode_mask_results(mask_results))
         
     | 
| 66 | 
         
            +
                                  for bbox_results, mask_results in result]
         
     | 
| 67 | 
         
            +
                    # This logic is only used in panoptic segmentation test.
         
     | 
| 68 | 
         
            +
                    elif isinstance(result[0], dict) and 'ins_results' in result[0]:
         
     | 
| 69 | 
         
            +
                        for j in range(len(result)):
         
     | 
| 70 | 
         
            +
                            bbox_results, mask_results = result[j]['ins_results']
         
     | 
| 71 | 
         
            +
                            result[j]['ins_results'] = (bbox_results,
         
     | 
| 72 | 
         
            +
                                                        encode_mask_results(mask_results))
         
     | 
| 73 | 
         
            +
             
     | 
| 74 | 
         
            +
                    results.extend(result)
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
                    for _ in range(batch_size):
         
     | 
| 77 | 
         
            +
                        prog_bar.update()
         
     | 
| 78 | 
         
            +
                return results
         
     | 
| 79 | 
         
            +
             
     | 
| 80 | 
         
            +
             
     | 
| 81 | 
         
            +
            def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
         
     | 
| 82 | 
         
            +
                """Test model with multiple gpus.
         
     | 
| 83 | 
         
            +
             
     | 
| 84 | 
         
            +
                This method tests model with multiple gpus and collects the results
         
     | 
| 85 | 
         
            +
                under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
         
     | 
| 86 | 
         
            +
                it encodes results to gpu tensors and use gpu communication for results
         
     | 
| 87 | 
         
            +
                collection. On cpu mode it saves the results on different gpus to 'tmpdir'
         
     | 
| 88 | 
         
            +
                and collects them by the rank 0 worker.
         
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
                Args:
         
     | 
| 91 | 
         
            +
                    model (nn.Module): Model to be tested.
         
     | 
| 92 | 
         
            +
                    data_loader (nn.Dataloader): Pytorch data loader.
         
     | 
| 93 | 
         
            +
                    tmpdir (str): Path of directory to save the temporary results from
         
     | 
| 94 | 
         
            +
                        different gpus under cpu mode.
         
     | 
| 95 | 
         
            +
                    gpu_collect (bool): Option to use either gpu or cpu to collect results.
         
     | 
| 96 | 
         
            +
             
     | 
| 97 | 
         
            +
                Returns:
         
     | 
| 98 | 
         
            +
                    list: The prediction results.
         
     | 
| 99 | 
         
            +
                """
         
     | 
| 100 | 
         
            +
                model.eval()
         
     | 
| 101 | 
         
            +
                results = []
         
     | 
| 102 | 
         
            +
                dataset = data_loader.dataset
         
     | 
| 103 | 
         
            +
                rank, world_size = get_dist_info()
         
     | 
| 104 | 
         
            +
                if rank == 0:
         
     | 
| 105 | 
         
            +
                    prog_bar = mmcv.ProgressBar(len(dataset))
         
     | 
| 106 | 
         
            +
                time.sleep(2)  # This line can prevent deadlock problem in some cases.
         
     | 
| 107 | 
         
            +
                for i, data in enumerate(data_loader):
         
     | 
| 108 | 
         
            +
                    with torch.no_grad():
         
     | 
| 109 | 
         
            +
                        result = model(return_loss=False, rescale=True, **data)
         
     | 
| 110 | 
         
            +
                        # encode mask results
         
     | 
| 111 | 
         
            +
                        if isinstance(result[0], tuple):
         
     | 
| 112 | 
         
            +
                            result = [(bbox_results, encode_mask_results(mask_results))
         
     | 
| 113 | 
         
            +
                                      for bbox_results, mask_results in result]
         
     | 
| 114 | 
         
            +
                        # This logic is only used in panoptic segmentation test.
         
     | 
| 115 | 
         
            +
                        elif isinstance(result[0], dict) and 'ins_results' in result[0]:
         
     | 
| 116 | 
         
            +
                            for j in range(len(result)):
         
     | 
| 117 | 
         
            +
                                bbox_results, mask_results = result[j]['ins_results']
         
     | 
| 118 | 
         
            +
                                result[j]['ins_results'] = (
         
     | 
| 119 | 
         
            +
                                    bbox_results, encode_mask_results(mask_results))
         
     | 
| 120 | 
         
            +
             
     | 
| 121 | 
         
            +
                    results.extend(result)
         
     | 
| 122 | 
         
            +
             
     | 
| 123 | 
         
            +
                    if rank == 0:
         
     | 
| 124 | 
         
            +
                        batch_size = len(result)
         
     | 
| 125 | 
         
            +
                        for _ in range(batch_size * world_size):
         
     | 
| 126 | 
         
            +
                            prog_bar.update()
         
     | 
| 127 | 
         
            +
             
     | 
| 128 | 
         
            +
                # collect results from all ranks
         
     | 
| 129 | 
         
            +
                if gpu_collect:
         
     | 
| 130 | 
         
            +
                    results = collect_results_gpu(results, len(dataset))
         
     | 
| 131 | 
         
            +
                else:
         
     | 
| 132 | 
         
            +
                    results = collect_results_cpu(results, len(dataset), tmpdir)
         
     | 
| 133 | 
         
            +
                return results
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
             
     | 
| 136 | 
         
            +
            def collect_results_cpu(result_part, size, tmpdir=None):
         
     | 
| 137 | 
         
            +
                rank, world_size = get_dist_info()
         
     | 
| 138 | 
         
            +
                # create a tmp dir if it is not specified
         
     | 
| 139 | 
         
            +
                if tmpdir is None:
         
     | 
| 140 | 
         
            +
                    MAX_LEN = 512
         
     | 
| 141 | 
         
            +
                    # 32 is whitespace
         
     | 
| 142 | 
         
            +
                    dir_tensor = torch.full((MAX_LEN, ),
         
     | 
| 143 | 
         
            +
                                            32,
         
     | 
| 144 | 
         
            +
                                            dtype=torch.uint8,
         
     | 
| 145 | 
         
            +
                                            device='cuda')
         
     | 
| 146 | 
         
            +
                    if rank == 0:
         
     | 
| 147 | 
         
            +
                        mmcv.mkdir_or_exist('.dist_test')
         
     | 
| 148 | 
         
            +
                        tmpdir = tempfile.mkdtemp(dir='.dist_test')
         
     | 
| 149 | 
         
            +
                        tmpdir = torch.tensor(
         
     | 
| 150 | 
         
            +
                            bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
         
     | 
| 151 | 
         
            +
                        dir_tensor[:len(tmpdir)] = tmpdir
         
     | 
| 152 | 
         
            +
                    dist.broadcast(dir_tensor, 0)
         
     | 
| 153 | 
         
            +
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f'part_{i}.pkl')
            part_list.append(mmcv.load(part_file))
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results


def collect_results_gpu(result_part, size):
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)

    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            part_list.append(
                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        return ordered_results
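
The tensor round trip in collect_results_gpu is worth isolating: arbitrary Python results are pickled into a uint8 tensor so they can travel through dist.all_gather, then unpickled on rank 0. A minimal sketch of just that serialization step, runnable on CPU without any distributed setup; the sample data is illustrative and not part of the file above:

    import pickle
    import torch

    # pickle arbitrary per-rank results into a uint8 tensor,
    # as collect_results_gpu does before all_gather
    result_part = [{'bboxes': [[0, 0, 10, 10]]}, {'bboxes': []}]
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8)
    # after gathering, rank 0 unpickles the received bytes
    restored = pickle.loads(part_tensor.numpy().tobytes())
    assert restored == result_part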
    	
mmdet/apis/train.py
ADDED
@@ -0,0 +1,246 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import random

import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
                         Fp16OptimizerHook, OptimizerHook, build_runner,
                         get_dist_info)

from mmdet.core import DistEvalHook, EvalHook, build_optimizer
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.utils import (build_ddp, build_dp, compat_cfg,
                         find_latest_checkpoint, get_root_logger)


def init_random_seed(seed=None, device='cuda'):
    """Initialize random seed.

    If the seed is not set, it will be randomized automatically and then
    broadcast to all processes to prevent some potential bugs.

    Args:
        seed (int, Optional): The seed. Defaults to None.
        device (str): The device where the seed will be put on.
            Defaults to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()


def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            the CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def auto_scale_lr(cfg, distributed, logger):
    """Automatically scale the LR according to the number of GPUs and the
    number of samples per GPU.

    Args:
        cfg (config): Training config.
        distributed (bool): Whether distributed training is used.
        logger (logging.Logger): Logger.
    """
    # Get flag from config
    if ('auto_scale_lr' not in cfg) or \
            (not cfg.auto_scale_lr.get('enable', False)):
        logger.info('Automatic scaling of learning rate (LR)'
                    ' has been disabled.')
        return

    # Get base batch size from config
    base_batch_size = cfg.auto_scale_lr.get('base_batch_size', None)
    if base_batch_size is None:
        return

    # Get gpu number
    if distributed:
        _, world_size = get_dist_info()
        num_gpus = world_size
    else:
        num_gpus = len(cfg.gpu_ids)

    # calculate the batch size
    samples_per_gpu = cfg.data.train_dataloader.samples_per_gpu
    batch_size = num_gpus * samples_per_gpu
    logger.info(f'Training with {num_gpus} GPU(s) with {samples_per_gpu} '
                f'samples per GPU. The total batch size is {batch_size}.')

    if batch_size != base_batch_size:
        # scale LR with
        # [linear scaling rule](https://arxiv.org/abs/1706.02677)
        scaled_lr = (batch_size / base_batch_size) * cfg.optimizer.lr
        logger.info('LR has been automatically scaled '
                    f'from {cfg.optimizer.lr} to {scaled_lr}')
        cfg.optimizer.lr = scaled_lr
    else:
        logger.info('The batch size matches the '
                    f'base batch size: {base_batch_size}, '
                    f'so the LR will not be scaled ({cfg.optimizer.lr}).')


def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):

    cfg = compat_cfg(cfg)
    logger = get_root_logger(log_level=cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
        'type']

    train_dataloader_default_args = dict(
        samples_per_gpu=2,
        workers_per_gpu=2,
        # `num_gpus` will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        seed=cfg.seed,
        runner_type=runner_type,
        persistent_workers=False)

    train_loader_cfg = {
        **train_dataloader_default_args,
        **cfg.data.get('train_dataloader', {})
    }

    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = build_ddp(
            model,
            cfg.device,
            device_ids=[int(os.environ['LOCAL_RANK'])],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)

    # build optimizer
    auto_scale_lr(cfg, distributed, logger)
    optimizer = build_optimizer(model, cfg.optimizer)

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is None and cfg.get('device', None) == 'npu':
        fp16_cfg = dict(loss_scale='dynamic')
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))

    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_dataloader_default_args = dict(
            samples_per_gpu=1,
            workers_per_gpu=2,
            dist=distributed,
            shuffle=False,
            persistent_workers=False)

        val_dataloader_args = {
            **val_dataloader_default_args,
            **cfg.data.get('val_dataloader', {})
        }
        # Support batch_size > 1 in validation

        if val_dataloader_args['samples_per_gpu'] > 1:
            # Replace 'ImageToTensor' with 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline)
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))

        val_dataloader = build_dataloader(val_dataset, **val_dataloader_args)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
        # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

    resume_from = None
    if cfg.resume_from is None and cfg.get('auto_resume'):
        resume_from = find_latest_checkpoint(cfg.work_dir)
    if resume_from is not None:
        cfg.resume_from = resume_from

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
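
auto_scale_lr above applies the linear scaling rule (https://arxiv.org/abs/1706.02677): the LR grows in proportion to the total batch size relative to a configured base batch size. A minimal sketch of that arithmetic, with a hypothetical helper name and illustrative numbers, not part of the file above:

    def linear_scale_lr(base_lr, num_gpus, samples_per_gpu, base_batch_size=16):
        """Scale base_lr by (total batch size / base batch size)."""
        batch_size = num_gpus * samples_per_gpu
        return base_lr * batch_size / base_batch_size

    # a base LR of 0.02 tuned for 8 GPUs x 2 samples (batch 16) becomes
    # 0.01 when training with 4 GPUs x 2 samples (batch 8)
    assert linear_scale_lr(0.02, 4, 2) == 0.01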
    	
mmdet/core/__init__.py
ADDED
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .anchor import *  # noqa: F401, F403
from .bbox import *  # noqa: F401, F403
from .data_structures import *  # noqa: F401, F403
from .evaluation import *  # noqa: F401, F403
from .hook import *  # noqa: F401, F403
from .mask import *  # noqa: F401, F403
from .optimizers import *  # noqa: F401, F403
from .post_processing import *  # noqa: F401, F403
from .utils import *  # noqa: F401, F403
    	
mmdet/core/anchor/__init__.py
ADDED
@@ -0,0 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .anchor_generator import (AnchorGenerator, LegacyAnchorGenerator,
                               YOLOAnchorGenerator)
from .builder import (ANCHOR_GENERATORS, PRIOR_GENERATORS,
                      build_anchor_generator, build_prior_generator)
from .point_generator import MlvlPointGenerator, PointGenerator
from .utils import anchor_inside_flags, calc_region, images_to_levels

__all__ = [
    'AnchorGenerator', 'LegacyAnchorGenerator', 'anchor_inside_flags',
    'PointGenerator', 'images_to_levels', 'calc_region',
    'build_anchor_generator', 'ANCHOR_GENERATORS', 'YOLOAnchorGenerator',
    'build_prior_generator', 'PRIOR_GENERATORS', 'MlvlPointGenerator'
]
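
The builders re-exported here construct any registered prior generator from a config dict. A minimal usage sketch, assuming mmdet is importable in the Space's environment; the config values are illustrative only and not part of the file above:

    from mmdet.core.anchor import build_anchor_generator

    # one square base anchor per location of a stride-16 feature map
    anchor_generator = build_anchor_generator(
        dict(type='AnchorGenerator', strides=[16], ratios=[1.0], scales=[1.0]))
    print(anchor_generator.num_base_anchors)  # [1]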
    	
mmdet/core/anchor/anchor_generator.py
ADDED
@@ -0,0 +1,866 @@
| 1 | 
         
            +
            # Copyright (c) OpenMMLab. All rights reserved.
         
     | 
| 2 | 
         
            +
            import warnings
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            import mmcv
         
     | 
| 5 | 
         
            +
            import numpy as np
         
     | 
| 6 | 
         
            +
            import torch
         
     | 
| 7 | 
         
            +
            from torch.nn.modules.utils import _pair
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            from .builder import PRIOR_GENERATORS
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
            @PRIOR_GENERATORS.register_module()
         
     | 
| 13 | 
         
            +
            class AnchorGenerator:
         
     | 
| 14 | 
         
            +
                """Standard anchor generator for 2D anchor-based detectors.
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
                Args:
         
     | 
| 17 | 
         
            +
                    strides (list[int] | list[tuple[int, int]]): Strides of anchors
         
     | 
| 18 | 
         
            +
                        in multiple feature levels in order (w, h).
         
     | 
| 19 | 
         
            +
                    ratios (list[float]): The list of ratios between the height and width
         
     | 
| 20 | 
         
            +
                        of anchors in a single level.
         
     | 
| 21 | 
         
            +
                    scales (list[int] | None): Anchor scales for anchors in a single level.
         
     | 
| 22 | 
         
            +
                        It cannot be set at the same time if `octave_base_scale` and
         
     | 
| 23 | 
         
            +
                        `scales_per_octave` are set.
         
     | 
| 24 | 
         
            +
                    base_sizes (list[int] | None): The basic sizes
         
     | 
| 25 | 
         
            +
                        of anchors in multiple levels.
         
     | 
| 26 | 
         
            +
                        If None is given, strides will be used as base_sizes.
         
     | 
| 27 | 
         
            +
                        (If strides are non square, the shortest stride is taken.)
         
     | 
| 28 | 
         
            +
                    scale_major (bool): Whether to multiply scales first when generating
         
     | 
| 29 | 
         
            +
                        base anchors. If true, the anchors in the same row will have the
         
     | 
| 30 | 
         
            +
                        same scales. By default it is True in V2.0
         
     | 
| 31 | 
         
            +
                    octave_base_scale (int): The base scale of octave.
         
     | 
| 32 | 
         
            +
                    scales_per_octave (int): Number of scales for each octave.
         
     | 
| 33 | 
         
            +
                        `octave_base_scale` and `scales_per_octave` are usually used in
         
     | 
| 34 | 
         
            +
                        retinanet and the `scales` should be None when they are set.
         
     | 
| 35 | 
         
            +
                    centers (list[tuple[float, float]] | None): The centers of the anchor
         
     | 
| 36 | 
         
            +
                        relative to the feature grid center in multiple feature levels.
         
     | 
| 37 | 
         
            +
                        By default it is set to be None and not used. If a list of tuple of
         
     | 
| 38 | 
         
            +
                        float is given, they will be used to shift the centers of anchors.
         
     | 
| 39 | 
         
            +
                    center_offset (float): The offset of center in proportion to anchors'
         
     | 
| 40 | 
         
            +
                        width and height. By default it is 0 in V2.0.
         
     | 
| 41 | 
         
            +
             
     | 
| 42 | 
         
            +
                Examples:
         
     | 
| 43 | 
         
            +
                    >>> from mmdet.core import AnchorGenerator
         
     | 
| 44 | 
         
            +
                    >>> self = AnchorGenerator([16], [1.], [1.], [9])
         
     | 
| 45 | 
         
            +
                    >>> all_anchors = self.grid_priors([(2, 2)], device='cpu')
         
     | 
| 46 | 
         
            +
                    >>> print(all_anchors)
         
     | 
| 47 | 
         
            +
                    [tensor([[-4.5000, -4.5000,  4.5000,  4.5000],
         
     | 
| 48 | 
         
            +
                            [11.5000, -4.5000, 20.5000,  4.5000],
         
     | 
| 49 | 
         
            +
                            [-4.5000, 11.5000,  4.5000, 20.5000],
         
     | 
| 50 | 
         
            +
                            [11.5000, 11.5000, 20.5000, 20.5000]])]
         
     | 
| 51 | 
         
            +
                    >>> self = AnchorGenerator([16, 32], [1.], [1.], [9, 18])
         
     | 
| 52 | 
         
            +
                    >>> all_anchors = self.grid_priors([(2, 2), (1, 1)], device='cpu')
         
     | 
| 53 | 
         
            +
                    >>> print(all_anchors)
         
     | 
| 54 | 
         
            +
                    [tensor([[-4.5000, -4.5000,  4.5000,  4.5000],
         
     | 
| 55 | 
         
            +
                            [11.5000, -4.5000, 20.5000,  4.5000],
         
     | 
| 56 | 
         
            +
                            [-4.5000, 11.5000,  4.5000, 20.5000],
         
     | 
| 57 | 
         
            +
                            [11.5000, 11.5000, 20.5000, 20.5000]]), \
         
     | 
| 58 | 
         
            +
                    tensor([[-9., -9., 9., 9.]])]
         
     | 
| 59 | 
         
            +
                """
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
    def __init__(self,
                 strides,
                 ratios,
                 scales=None,
                 base_sizes=None,
                 scale_major=True,
                 octave_base_scale=None,
                 scales_per_octave=None,
                 centers=None,
                 center_offset=0.):
        # check center and center_offset
        if center_offset != 0:
            assert centers is None, 'center cannot be set when center_offset' \
                                    f' != 0, {centers} is given.'
        if not (0 <= center_offset <= 1):
            raise ValueError('center_offset should be in range [0, 1], '
                             f'{center_offset} is given.')
        if centers is not None:
            assert len(centers) == len(strides), \
                'The number of strides should be the same as centers, got ' \
                f'{strides} and {centers}'

        # calculate base sizes of anchors
        self.strides = [_pair(stride) for stride in strides]
        self.base_sizes = [min(stride) for stride in self.strides
                           ] if base_sizes is None else base_sizes
        assert len(self.base_sizes) == len(self.strides), \
            'The number of strides should be the same as base sizes, got ' \
            f'{self.strides} and {self.base_sizes}'

        # calculate scales of anchors
        assert ((octave_base_scale is not None
                 and scales_per_octave is not None) ^ (scales is not None)), \
            'scales and octave_base_scale with scales_per_octave cannot' \
            ' be set at the same time'
        if scales is not None:
            self.scales = torch.Tensor(scales)
        elif octave_base_scale is not None and scales_per_octave is not None:
            octave_scales = np.array(
                [2**(i / scales_per_octave) for i in range(scales_per_octave)])
            scales = octave_scales * octave_base_scale
            self.scales = torch.Tensor(scales)
        else:
            raise ValueError('Either scales or octave_base_scale with '
                             'scales_per_octave should be set')

        self.octave_base_scale = octave_base_scale
        self.scales_per_octave = scales_per_octave
        self.ratios = torch.Tensor(ratios)
        self.scale_major = scale_major
        self.centers = centers
        self.center_offset = center_offset
        self.base_anchors = self.gen_base_anchors()
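
    # A minimal sketch of the octave arithmetic above, assuming
    # RetinaNet-style settings (octave_base_scale=4, scales_per_octave=3):
    # the comprehension computes 2**(0/3), 2**(1/3), 2**(2/3), so the
    # resulting scales should be roughly 4.00, 5.04 and 6.35, e.g.
    #
    #   >>> gen = AnchorGenerator(strides=[8], ratios=[1.0],
    #   ...                       octave_base_scale=4, scales_per_octave=3)
    #   >>> gen.scales  # expected: tensor([4.0000, 5.0397, 6.3496])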

    @property
    def num_base_anchors(self):
        """list[int]: total number of base anchors in a feature grid"""
        return self.num_base_priors

    @property
    def num_base_priors(self):
        """list[int]: The number of priors (anchors) at a point
        on the feature grid"""
        return [base_anchors.size(0) for base_anchors in self.base_anchors]

    @property
    def num_levels(self):
        """int: number of feature levels that the generator is applied to"""
        return len(self.strides)

    def gen_base_anchors(self):
        """Generate base anchors.

        Returns:
            list(torch.Tensor): Base anchors of a feature grid in multiple \
                feature levels.
        """
        multi_level_base_anchors = []
        for i, base_size in enumerate(self.base_sizes):
            center = None
            if self.centers is not None:
                center = self.centers[i]
            multi_level_base_anchors.append(
                self.gen_single_level_base_anchors(
                    base_size,
                    scales=self.scales,
                    ratios=self.ratios,
                    center=center))
        return multi_level_base_anchors

    def gen_single_level_base_anchors(self,
                                      base_size,
                                      scales,
                                      ratios,
                                      center=None):
        """Generate base anchors of a single level.

        Args:
            base_size (int | float): Basic size of an anchor.
            scales (torch.Tensor): Scales of the anchor.
            ratios (torch.Tensor): The ratio between the height
                and width of anchors in a single level.
            center (tuple[float], optional): The center of the base anchor
                related to a single feature grid. Defaults to None.

        Returns:
            torch.Tensor: Anchors in a single-level feature map.
        """
        w = base_size
        h = base_size
        if center is None:
            x_center = self.center_offset * w
            y_center = self.center_offset * h
        else:
            x_center, y_center = center

        h_ratios = torch.sqrt(ratios)
        w_ratios = 1 / h_ratios
        if self.scale_major:
            ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
            hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
        else:
            ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
            hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)

        # use float anchor and the anchor's center is aligned with the
        # pixel center
        base_anchors = [
            x_center - 0.5 * ws, y_center - 0.5 * hs, x_center + 0.5 * ws,
            y_center + 0.5 * hs
        ]
        base_anchors = torch.stack(base_anchors, dim=-1)

        return base_anchors
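
    # To make the arithmetic above concrete: with base_size=9, ratios=[1.],
    # scales=[1.] and the default center_offset=0., the center is (0, 0) and
    # ws = hs = 9, so the single base anchor is
    # (x1, y1, x2, y2) = (-4.5, -4.5, 4.5, 4.5), matching the first row of
    # the class docstring example.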

    def _meshgrid(self, x, y, row_major=True):
        """Generate mesh grid of x and y.

        Args:
            x (torch.Tensor): Grids of x dimension.
            y (torch.Tensor): Grids of y dimension.
            row_major (bool, optional): Whether to return the grids in
                row-major order, i.e. with the x grids first.
                Defaults to True.

        Returns:
            tuple[torch.Tensor]: The mesh grids of x and y.
        """
        # use shape instead of len to keep tracing while exporting to onnx
        xx = x.repeat(y.shape[0])
        yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
        if row_major:
            return xx, yy
        else:
            return yy, xx
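
    # A small sketch of the ordering, assuming x = tensor([0, 16]) and
    # y = tensor([0, 16]) (a 2x2 grid of stride 16): with row_major=True this
    # returns xx = [0, 16, 0, 16] and yy = [0, 0, 16, 16], i.e. x varies
    # fastest, so downstream anchors are emitted row by row.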

    def grid_priors(self, featmap_sizes, dtype=torch.float32, device='cuda'):
        """Generate grid anchors in multiple feature levels.

        Args:
            featmap_sizes (list[tuple]): List of feature map sizes in
                multiple feature levels.
            dtype (:obj:`torch.dtype`): Dtype of priors.
                Default: torch.float32.
            device (str): The device the anchors will be put on.

        Returns:
            list[torch.Tensor]: Anchors in multiple feature levels. \
                The sizes of each tensor should be [N, 4], where \
                N = width * height * num_base_anchors, width and height \
                are the sizes of the corresponding feature level, \
                num_base_anchors is the number of anchors for that level.
        """
        assert self.num_levels == len(featmap_sizes)
        multi_level_anchors = []
        for i in range(self.num_levels):
            anchors = self.single_level_grid_priors(
                featmap_sizes[i], level_idx=i, dtype=dtype, device=device)
            multi_level_anchors.append(anchors)
        return multi_level_anchors

    def single_level_grid_priors(self,
                                 featmap_size,
                                 level_idx,
                                 dtype=torch.float32,
                                 device='cuda'):
        """Generate grid anchors of a single level.

        Note:
            This function is usually called by method ``self.grid_priors``.

        Args:
            featmap_size (tuple[int]): Size of the feature map.
            level_idx (int): The index of the corresponding feature map level.
            dtype (obj:`torch.dtype`): Data type of points. Defaults to
                ``torch.float32``.
            device (str, optional): The device the tensor will be put on.
                Defaults to 'cuda'.

        Returns:
            torch.Tensor: Anchors in the overall feature map.
        """

        base_anchors = self.base_anchors[level_idx].to(device).to(dtype)
        feat_h, feat_w = featmap_size
        stride_w, stride_h = self.strides[level_idx]
        # First create the range with the default dtype, then convert to
        # the target `dtype` for onnx exporting.
        shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w
        shift_y = torch.arange(0, feat_h, device=device).to(dtype) * stride_h

        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
        # first feat_w elements correspond to the first row of shifts
        # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
        # shifted anchors (K, A, 4), reshape to (K*A, 4)

        all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
        all_anchors = all_anchors.view(-1, 4)
        # first A rows correspond to A anchors of (0, 0) in feature map,
        # then (0, 1), (0, 2), ...
        return all_anchors
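
    # To make the broadcast above concrete: for a (2, 2) feature map with
    # stride 16 and the single base anchor [-4.5, -4.5, 4.5, 4.5], `shifts`
    # is a (4, 4) tensor of [0 or 16, 0 or 16] offsets, and
    # base_anchors[None] + shifts[:, None] yields the four anchors listed in
    # the class docstring, e.g. [11.5, -4.5, 20.5, 4.5] for grid point (1, 0).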

    def sparse_priors(self,
                      prior_idxs,
                      featmap_size,
                      level_idx,
                      dtype=torch.float32,
                      device='cuda'):
        """Generate sparse anchors according to the ``prior_idxs``.

        Args:
            prior_idxs (Tensor): The index of corresponding anchors
                in the feature map.
            featmap_size (tuple[int]): feature map size arranged as (h, w).
            level_idx (int): The level index of the corresponding feature
                map.
            dtype (obj:`torch.dtype`): Data type of points. Defaults to
                ``torch.float32``.
            device (obj:`torch.device`): The device where the points are
                located.
        Returns:
            Tensor: Anchor with shape (N, 4), N should be equal to
                the length of ``prior_idxs``.
        """

        height, width = featmap_size
        num_base_anchors = self.num_base_anchors[level_idx]
        base_anchor_id = prior_idxs % num_base_anchors
        x = (prior_idxs //
             num_base_anchors) % width * self.strides[level_idx][0]
        y = (prior_idxs // width //
             num_base_anchors) % height * self.strides[level_idx][1]
        priors = torch.stack([x, y, x, y], 1).to(dtype).to(device) + \
            self.base_anchors[level_idx][base_anchor_id, :].to(device)

        return priors
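
    # Index decoding sketch, assuming the flat layout produced by
    # ``grid_priors`` (num_base_anchors per point, x varying fastest): with
    # width=2, num_base_anchors=1 and prior_idxs=[3], the code above gives
    # base_anchor_id=0, x = (3 // 1) % 2 * stride_w = stride_w and
    # y = (3 // 2 // 1) % height * stride_h = stride_h, i.e. the anchor at
    # grid point (1, 1).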

    def grid_anchors(self, featmap_sizes, device='cuda'):
        """Generate grid anchors in multiple feature levels.

        Args:
            featmap_sizes (list[tuple]): List of feature map sizes in
                multiple feature levels.
            device (str): Device where the anchors will be put on.

        Returns:
            list[torch.Tensor]: Anchors in multiple feature levels. \
                The sizes of each tensor should be [N, 4], where \
                N = width * height * num_base_anchors, width and height \
                are the sizes of the corresponding feature level, \
                num_base_anchors is the number of anchors for that level.
        """
        warnings.warn('``grid_anchors`` will be deprecated soon. '
                      'Please use ``grid_priors`` ')

        assert self.num_levels == len(featmap_sizes)
        multi_level_anchors = []
        for i in range(self.num_levels):
            anchors = self.single_level_grid_anchors(
                self.base_anchors[i].to(device),
                featmap_sizes[i],
                self.strides[i],
                device=device)
            multi_level_anchors.append(anchors)
        return multi_level_anchors

    def single_level_grid_anchors(self,
                                  base_anchors,
                                  featmap_size,
                                  stride=(16, 16),
                                  device='cuda'):
        """Generate grid anchors of a single level.

        Note:
            This function is usually called by method ``self.grid_anchors``.

        Args:
            base_anchors (torch.Tensor): The base anchors of a feature grid.
            featmap_size (tuple[int]): Size of the feature map.
            stride (tuple[int], optional): Stride of the feature map in order
                (w, h). Defaults to (16, 16).
            device (str, optional): Device the tensor will be put on.
                Defaults to 'cuda'.

        Returns:
            torch.Tensor: Anchors in the overall feature map.
        """

        warnings.warn(
            '``single_level_grid_anchors`` will be deprecated soon. '
            'Please use ``single_level_grid_priors`` ')

        # keep featmap_size as Tensor instead of int, so that we
        # can convert to ONNX correctly
        feat_h, feat_w = featmap_size
        shift_x = torch.arange(0, feat_w, device=device) * stride[0]
        shift_y = torch.arange(0, feat_h, device=device) * stride[1]

        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
        shifts = shifts.type_as(base_anchors)
        # first feat_w elements correspond to the first row of shifts
        # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
        # shifted anchors (K, A, 4), reshape to (K*A, 4)

        all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
        all_anchors = all_anchors.view(-1, 4)
        # first A rows correspond to A anchors of (0, 0) in feature map,
        # then (0, 1), (0, 2), ...
        return all_anchors

    def valid_flags(self, featmap_sizes, pad_shape, device='cuda'):
        """Generate valid flags of anchors in multiple feature levels.

        Args:
            featmap_sizes (list(tuple)): List of feature map sizes in
                multiple feature levels.
            pad_shape (tuple): The padded shape of the image.
            device (str): Device where the anchors will be put on.

        Returns:
            list(torch.Tensor): Valid flags of anchors in multiple levels.
        """
        assert self.num_levels == len(featmap_sizes)
        multi_level_flags = []
        for i in range(self.num_levels):
            anchor_stride = self.strides[i]
            feat_h, feat_w = featmap_sizes[i]
            h, w = pad_shape[:2]
            valid_feat_h = min(int(np.ceil(h / anchor_stride[1])), feat_h)
            valid_feat_w = min(int(np.ceil(w / anchor_stride[0])), feat_w)
            flags = self.single_level_valid_flags((feat_h, feat_w),
                                                  (valid_feat_h, valid_feat_w),
                                                  self.num_base_anchors[i],
                                                  device=device)
            multi_level_flags.append(flags)
        return multi_level_flags
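
    # Example of the valid-size computation above, assuming a padded image of
    # (800, 1216) and stride 32 with a (25, 38) feature map:
    # valid_feat_h = min(ceil(800 / 32), 25) = 25 and
    # valid_feat_w = min(ceil(1216 / 32), 38) = 38, so every location is
    # valid; had the batch padding exceeded the image size, the trailing
    # rows/columns would be flagged invalid instead.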

    def single_level_valid_flags(self,
                                 featmap_size,
                                 valid_size,
                                 num_base_anchors,
                                 device='cuda'):
        """Generate the valid flags of anchors in a single feature map.

        Args:
            featmap_size (tuple[int]): The size of feature maps, arranged
                as (h, w).
            valid_size (tuple[int]): The valid size of the feature maps.
            num_base_anchors (int): The number of base anchors.
            device (str, optional): Device where the flags will be put on.
                Defaults to 'cuda'.

        Returns:
            torch.Tensor: The valid flags of each anchor in a single level \
                feature map.
        """
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w
        valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
        valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
        valid_x[:valid_w] = 1
        valid_y[:valid_h] = 1
        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
        valid = valid_xx & valid_yy
        valid = valid[:, None].expand(valid.size(0),
                                      num_base_anchors).contiguous().view(-1)
        return valid
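
    # Minimal sketch of the flag layout, assuming featmap_size=(2, 2),
    # valid_size=(2, 1) and num_base_anchors=1: valid_x = [True, False] and
    # valid_y = [True, True], so after the meshgrid and logical AND the flags
    # are [True, False, True, False] -- one flag per anchor, in the same
    # x-fastest order as ``grid_priors``.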

    def __repr__(self):
        """str: a string that describes the module"""
        indent_str = '    '
        repr_str = self.__class__.__name__ + '(\n'
        repr_str += f'{indent_str}strides={self.strides},\n'
        repr_str += f'{indent_str}ratios={self.ratios},\n'
        repr_str += f'{indent_str}scales={self.scales},\n'
        repr_str += f'{indent_str}base_sizes={self.base_sizes},\n'
        repr_str += f'{indent_str}scale_major={self.scale_major},\n'
        repr_str += f'{indent_str}octave_base_scale='
        repr_str += f'{self.octave_base_scale},\n'
        repr_str += f'{indent_str}scales_per_octave='
        repr_str += f'{self.scales_per_octave},\n'
        repr_str += f'{indent_str}num_levels={self.num_levels},\n'
        repr_str += f'{indent_str}centers={self.centers},\n'
        repr_str += f'{indent_str}center_offset={self.center_offset})'
        return repr_str


@PRIOR_GENERATORS.register_module()
class SSDAnchorGenerator(AnchorGenerator):
    """Anchor generator for SSD.

    Args:
        strides (list[int] | list[tuple[int, int]]): Strides of anchors
            in multiple feature levels.
        ratios (list[float]): The list of ratios between the height and width
            of anchors in a single level.
        min_sizes (list[float]): The list of minimum anchor sizes on each
            level.
        max_sizes (list[float]): The list of maximum anchor sizes on each
            level.
        basesize_ratio_range (tuple(float)): Ratio range of anchors. Used
            when min_sizes and max_sizes are not set.
        input_size (int): Size of the input image, 300 for SSD300, 512 for
            SSD512. Used when min_sizes and max_sizes are not set.
        scale_major (bool): Whether to multiply scales first when generating
            base anchors. If true, the anchors in the same row will have the
            same scales. It is always set to be False in SSD.
    """

    def __init__(self,
                 strides,
                 ratios,
                 min_sizes=None,
                 max_sizes=None,
                 basesize_ratio_range=(0.15, 0.9),
                 input_size=300,
                 scale_major=True):
        assert len(strides) == len(ratios)
        assert not (min_sizes is None) ^ (max_sizes is None)
        self.strides = [_pair(stride) for stride in strides]
        self.centers = [(stride[0] / 2., stride[1] / 2.)
                        for stride in self.strides]

        if min_sizes is None and max_sizes is None:
            # use hard code to generate SSD anchors
            self.input_size = input_size
            assert mmcv.is_tuple_of(basesize_ratio_range, float)
            self.basesize_ratio_range = basesize_ratio_range
            # calculate anchor ratios and sizes
            min_ratio, max_ratio = basesize_ratio_range
            min_ratio = int(min_ratio * 100)
            max_ratio = int(max_ratio * 100)
            step = int(np.floor(max_ratio - min_ratio) / (self.num_levels - 2))
            min_sizes = []
            max_sizes = []
            for ratio in range(int(min_ratio), int(max_ratio) + 1, step):
                min_sizes.append(int(self.input_size * ratio / 100))
                max_sizes.append(int(self.input_size * (ratio + step) / 100))
            if self.input_size == 300:
                if basesize_ratio_range[0] == 0.15:  # SSD300 COCO
                    min_sizes.insert(0, int(self.input_size * 7 / 100))
                    max_sizes.insert(0, int(self.input_size * 15 / 100))
                elif basesize_ratio_range[0] == 0.2:  # SSD300 VOC
                    min_sizes.insert(0, int(self.input_size * 10 / 100))
                    max_sizes.insert(0, int(self.input_size * 20 / 100))
                else:
                    raise ValueError(
                        'basesize_ratio_range[0] should be either 0.15 '
                        'or 0.2 when input_size is 300, got '
                        f'{basesize_ratio_range[0]}.')
            elif self.input_size == 512:
                if basesize_ratio_range[0] == 0.1:  # SSD512 COCO
                    min_sizes.insert(0, int(self.input_size * 4 / 100))
                    max_sizes.insert(0, int(self.input_size * 10 / 100))
                elif basesize_ratio_range[0] == 0.15:  # SSD512 VOC
                    min_sizes.insert(0, int(self.input_size * 7 / 100))
                    max_sizes.insert(0, int(self.input_size * 15 / 100))
                else:
                    raise ValueError(
                        'When not setting min_sizes and max_sizes, '
                        'basesize_ratio_range[0] should be either 0.1 '
                        'or 0.15 when input_size is 512, got '
                        f'{basesize_ratio_range[0]}.')
            else:
                raise ValueError(
                    'Only support 300 or 512 in SSDAnchorGenerator when '
                    'not setting min_sizes and max_sizes, '
                    f'got {self.input_size}.')

        assert len(min_sizes) == len(max_sizes) == len(strides)

        anchor_ratios = []
        anchor_scales = []
        for k in range(len(self.strides)):
            scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
            anchor_ratio = [1.]
            for r in ratios[k]:
                anchor_ratio += [1 / r, r]  # 4 or 6 ratio
            anchor_ratios.append(torch.Tensor(anchor_ratio))
            anchor_scales.append(torch.Tensor(scales))

        self.base_sizes = min_sizes
        self.scales = anchor_scales
        self.ratios = anchor_ratios
        self.scale_major = scale_major
        self.center_offset = 0
        self.base_anchors = self.gen_base_anchors()
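
    # Worked example of the hard-coded size computation above for SSD300 on
    # COCO (input_size=300, basesize_ratio_range=(0.15, 0.9), 6 levels):
    # step = int((90 - 15) / 4) = 18, so the ratio sweep 15, 33, 51, 69, 87
    # gives min_sizes [45, 99, 153, 207, 261] and max_sizes
    # [99, 153, 207, 261, 315]; the input_size == 300 branch then prepends
    # the extra first level (min 21, max 45).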

    def gen_base_anchors(self):
        """Generate base anchors.

        Returns:
            list(torch.Tensor): Base anchors of a feature grid in multiple \
                feature levels.
        """
        multi_level_base_anchors = []
        for i, base_size in enumerate(self.base_sizes):
            base_anchors = self.gen_single_level_base_anchors(
                base_size,
                scales=self.scales[i],
                ratios=self.ratios[i],
                center=self.centers[i])
            indices = list(range(len(self.ratios[i])))
            indices.insert(1, len(indices))
            base_anchors = torch.index_select(base_anchors, 0,
                                              torch.LongTensor(indices))
            multi_level_base_anchors.append(base_anchors)
        return multi_level_base_anchors
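
    # The index_select above keeps SSD's usual anchor set. Assuming
    # ratios=[2] for a level (so anchor_ratio is [1., 0.5, 2.]) and
    # scale_major=False, as the class docstring notes for SSD, the parent
    # class emits (scale, ratio) pairs in scale-major order:
    # (1, 1), (1, 1/2), (1, 2), (sqrt(max/min), 1), ...; the indices
    # [0, 3, 1, 2] then select the ratio-1 anchor at both scales plus the
    # 1/2 and 2 anchors at the base scale.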

    def __repr__(self):
        """str: a string that describes the module"""
        indent_str = '    '
        repr_str = self.__class__.__name__ + '(\n'
        repr_str += f'{indent_str}strides={self.strides},\n'
        repr_str += f'{indent_str}scales={self.scales},\n'
        repr_str += f'{indent_str}scale_major={self.scale_major},\n'
        repr_str += f'{indent_str}input_size={self.input_size},\n'
        repr_str += f'{indent_str}ratios={self.ratios},\n'
        repr_str += f'{indent_str}num_levels={self.num_levels},\n'
        repr_str += f'{indent_str}base_sizes={self.base_sizes},\n'
        repr_str += f'{indent_str}basesize_ratio_range='
        repr_str += f'{self.basesize_ratio_range})'
        return repr_str


@PRIOR_GENERATORS.register_module()
class LegacyAnchorGenerator(AnchorGenerator):
    """Legacy anchor generator used in MMDetection V1.x.

    Note:
        Difference to the V2.0 anchor generator:

        1. The center offset of V1.x anchors is set to be 0.5 rather than 0.
        2. The width/height are reduced by 1 when calculating the anchors' \
            centers and corners to meet the V1.x coordinate system.
        3. The anchors' corners are quantized.

    Args:
        strides (list[int] | list[tuple[int]]): Strides of anchors
            in multiple feature levels.
        ratios (list[float]): The list of ratios between the height and width
            of anchors in a single level.
        scales (list[int] | None): Anchor scales for anchors in a single level.
            It cannot be set at the same time as `octave_base_scale` and
            `scales_per_octave`.
        base_sizes (list[int]): The basic sizes of anchors in multiple levels.
            If None is given, strides will be used to generate base_sizes.
        scale_major (bool): Whether to multiply scales first when generating
            base anchors. If true, the anchors in the same row will have the
            same scales. By default it is True in V2.0.
        octave_base_scale (int): The base scale of octave.
        scales_per_octave (int): Number of scales for each octave.
            `octave_base_scale` and `scales_per_octave` are usually used in
            retinanet and the `scales` should be None when they are set.
        centers (list[tuple[float, float]] | None): The centers of the anchor
            relative to the feature grid center in multiple feature levels.
            By default it is set to be None and not used. If a list of float
            is given, this list will be used to shift the centers of anchors.
        center_offset (float): The offset of center in proportion to anchors'
            width and height. By default it is 0 in V2.0 but it should be 0.5
            in v1.x models.

    Examples:
        >>> from mmdet.core import LegacyAnchorGenerator
        >>> self = LegacyAnchorGenerator(
        >>>     [16], [1.], [1.], [9], center_offset=0.5)
        >>> all_anchors = self.grid_anchors(((2, 2),), device='cpu')
        >>> print(all_anchors)
        [tensor([[ 0.,  0.,  8.,  8.],
                [16.,  0., 24.,  8.],
                [ 0., 16.,  8., 24.],
                [16., 16., 24., 24.]])]
    """
         
     | 
| 657 | 
         
            +
             
     | 
| 658 | 
         
            +
                def gen_single_level_base_anchors(self,
         
     | 
| 659 | 
         
            +
                                                  base_size,
         
     | 
| 660 | 
         
            +
                                                  scales,
         
     | 
| 661 | 
         
            +
                                                  ratios,
         
     | 
| 662 | 
         
            +
                                                  center=None):
         
     | 
| 663 | 
         
            +
                    """Generate base anchors of a single level.
         
     | 
| 664 | 
         
            +
             
     | 
| 665 | 
         
            +
                    Note:
         
     | 
| 666 | 
         
            +
                        The width/height of anchors are minused by 1 when calculating \
         
     | 
| 667 | 
         
            +
                            the centers and corners to meet the V1.x coordinate system.
         
     | 
| 668 | 
         
            +
             
     | 
| 669 | 
         
            +
                    Args:
         
     | 
| 670 | 
         
            +
                        base_size (int | float): Basic size of an anchor.
         
     | 
| 671 | 
         
            +
                        scales (torch.Tensor): Scales of the anchor.
         
     | 
| 672 | 
         
            +
                        ratios (torch.Tensor): The ratio between between the height.
         
     | 
| 673 | 
         
            +
                            and width of anchors in a single level.
         
     | 
| 674 | 
         
            +
                        center (tuple[float], optional): The center of the base anchor
         
     | 
| 675 | 
         
            +
                            related to a single feature grid. Defaults to None.
         
     | 
| 676 | 
         
            +
             
     | 
| 677 | 
         
            +
                    Returns:
         
     | 
| 678 | 
         
            +
                        torch.Tensor: Anchors in a single-level feature map.
         
     | 
| 679 | 
         
            +
                    """
         
     | 
| 680 | 
         
            +
                    w = base_size
         
     | 
| 681 | 
         
            +
                    h = base_size
         
     | 
| 682 | 
         
            +
                    if center is None:
         
     | 
| 683 | 
         
            +
                        x_center = self.center_offset * (w - 1)
         
     | 
| 684 | 
         
            +
                        y_center = self.center_offset * (h - 1)
         
     | 
| 685 | 
         
            +
                    else:
         
     | 
| 686 | 
         
            +
                        x_center, y_center = center
         
     | 
| 687 | 
         
            +
             
     | 
| 688 | 
         
            +
                    h_ratios = torch.sqrt(ratios)
         
     | 
| 689 | 
         
            +
                    w_ratios = 1 / h_ratios
         
     | 
| 690 | 
         
            +
                    if self.scale_major:
         
     | 
| 691 | 
         
            +
                        ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
         
     | 
| 692 | 
         
            +
                        hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
         
     | 
| 693 | 
         
            +
                    else:
         
     | 
| 694 | 
         
            +
                        ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
         
     | 
| 695 | 
         
            +
                        hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
         
     | 
| 696 | 
         
            +
             
     | 
| 697 | 
         
            +
                    # use float anchor and the anchor's center is aligned with the
         
     | 
| 698 | 
         
            +
                    # pixel center
         
     | 
| 699 | 
         
            +
                    base_anchors = [
         
     | 
| 700 | 
         
            +
                        x_center - 0.5 * (ws - 1), y_center - 0.5 * (hs - 1),
         
     | 
| 701 | 
         
            +
                        x_center + 0.5 * (ws - 1), y_center + 0.5 * (hs - 1)
         
     | 
| 702 | 
         
            +
                    ]
         
     | 
| 703 | 
         
            +
                    base_anchors = torch.stack(base_anchors, dim=-1).round()
         
     | 
| 704 | 
         
            +
             
     | 
| 705 | 
         
            +
                    return base_anchors
         
     | 
| 706 | 
         
            +
             
     | 
| 707 | 
         
            +
             
     | 
| 708 | 
         
            +
            @PRIOR_GENERATORS.register_module()
         
     | 
| 709 | 
         
            +
            class LegacySSDAnchorGenerator(SSDAnchorGenerator, LegacyAnchorGenerator):
         
     | 
| 710 | 
         
            +
                """Legacy anchor generator used in MMDetection V1.x.
         
     | 
| 711 | 
         
            +
             
     | 
| 712 | 
         
            +
                The difference between `LegacySSDAnchorGenerator` and `SSDAnchorGenerator`
         
     | 
| 713 | 
         
            +
                can be found in `LegacyAnchorGenerator`.
         
     | 
| 714 | 
         
            +
                """
         
     | 
| 715 | 
         
            +
             
     | 
| 716 | 
         
            +
                def __init__(self,
         
     | 
| 717 | 
         
            +
                             strides,
         
     | 
| 718 | 
         
            +
                             ratios,
         
     | 
| 719 | 
         
            +
                             basesize_ratio_range,
         
     | 
| 720 | 
         
            +
                             input_size=300,
         
     | 
| 721 | 
         
            +
                             scale_major=True):
         
     | 
| 722 | 
         
            +
                    super(LegacySSDAnchorGenerator, self).__init__(
         
     | 
| 723 | 
         
            +
                        strides=strides,
         
     | 
| 724 | 
         
            +
                        ratios=ratios,
         
     | 
| 725 | 
         
            +
                        basesize_ratio_range=basesize_ratio_range,
         
     | 
| 726 | 
         
            +
                        input_size=input_size,
         
     | 
| 727 | 
         
            +
                        scale_major=scale_major)
         
     | 
| 728 | 
         
            +
                    self.centers = [((stride - 1) / 2., (stride - 1) / 2.)
         
     | 
| 729 | 
         
            +
                                    for stride in strides]
         
     | 
| 730 | 
         
            +
                    self.base_anchors = self.gen_base_anchors()
         
     | 
| 731 | 
         
            +
             
     | 
| 732 | 
         
            +
             
     | 
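
# Editor's note: a minimal usage sketch for the class above, not part of the
# original file. The constructor values mirror a typical legacy SSD300 setup
# and are illustrative assumptions only:
#
# >>> gen = LegacySSDAnchorGenerator(
# ...     strides=[8, 16, 32, 64, 100, 300],
# ...     ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]],
# ...     basesize_ratio_range=(0.15, 0.9),
# ...     input_size=300)
# >>> len(gen.base_anchors)  # one base-anchor tensor per feature level
# 6
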
@PRIOR_GENERATORS.register_module()
class YOLOAnchorGenerator(AnchorGenerator):
    """Anchor generator for YOLO.

    Args:
        strides (list[int] | list[tuple[int, int]]): Strides of anchors
            in multiple feature levels.
        base_sizes (list[list[tuple[int, int]]]): The basic sizes
            of anchors in multiple levels.
    """

    def __init__(self, strides, base_sizes):
        self.strides = [_pair(stride) for stride in strides]
        self.centers = [(stride[0] / 2., stride[1] / 2.)
                        for stride in self.strides]
        self.base_sizes = []
        num_anchor_per_level = len(base_sizes[0])
        for base_sizes_per_level in base_sizes:
            assert num_anchor_per_level == len(base_sizes_per_level)
            self.base_sizes.append(
                [_pair(base_size) for base_size in base_sizes_per_level])
        self.base_anchors = self.gen_base_anchors()

    @property
    def num_levels(self):
        """int: number of feature levels that the generator is applied to"""
        return len(self.base_sizes)

    def gen_base_anchors(self):
        """Generate base anchors.

        Returns:
            list(torch.Tensor): Base anchors of a feature grid in multiple \
                feature levels.
        """
        multi_level_base_anchors = []
        for i, base_sizes_per_level in enumerate(self.base_sizes):
            center = None
            if self.centers is not None:
                center = self.centers[i]
            multi_level_base_anchors.append(
                self.gen_single_level_base_anchors(base_sizes_per_level,
                                                   center))
        return multi_level_base_anchors

    def gen_single_level_base_anchors(self, base_sizes_per_level, center=None):
        """Generate base anchors of a single level.

        Args:
            base_sizes_per_level (list[tuple[int, int]]): Basic sizes of
                anchors.
            center (tuple[float], optional): The center of the base anchor
                related to a single feature grid. Defaults to None.

        Returns:
            torch.Tensor: Anchors in a single-level feature map.
        """
        x_center, y_center = center
        base_anchors = []
        for base_size in base_sizes_per_level:
            w, h = base_size

            # use float anchor and the anchor's center is aligned with the
            # pixel center
            base_anchor = torch.Tensor([
                x_center - 0.5 * w, y_center - 0.5 * h, x_center + 0.5 * w,
                y_center + 0.5 * h
            ])
            base_anchors.append(base_anchor)
        base_anchors = torch.stack(base_anchors, dim=0)

        return base_anchors

    def responsible_flags(self, featmap_sizes, gt_bboxes, device='cuda'):
        """Generate responsible anchor flags of grid cells in multiple scales.

        Args:
            featmap_sizes (list(tuple)): List of feature map sizes in multiple
                feature levels.
            gt_bboxes (Tensor): Ground truth boxes, shape (n, 4).
            device (str): Device where the anchors will be put on.

        Returns:
            list(torch.Tensor): Responsible flags of anchors in multiple
                levels.
        """
        assert self.num_levels == len(featmap_sizes)
        multi_level_responsible_flags = []
        for i in range(self.num_levels):
            anchor_stride = self.strides[i]
            flags = self.single_level_responsible_flags(
                featmap_sizes[i],
                gt_bboxes,
                anchor_stride,
                self.num_base_anchors[i],
                device=device)
            multi_level_responsible_flags.append(flags)
        return multi_level_responsible_flags

    def single_level_responsible_flags(self,
                                       featmap_size,
                                       gt_bboxes,
                                       stride,
                                       num_base_anchors,
                                       device='cuda'):
        """Generate the responsible flags of anchors in a single feature map.

        Args:
            featmap_size (tuple[int]): The size of feature maps.
            gt_bboxes (Tensor): Ground truth boxes, shape (n, 4).
            stride (tuple(int)): Stride of the current level.
            num_base_anchors (int): The number of base anchors.
            device (str, optional): Device where the flags will be put on.
                Defaults to 'cuda'.

        Returns:
            torch.Tensor: The valid flags of each anchor in a single level \
                feature map.
        """
        feat_h, feat_w = featmap_size
        gt_bboxes_cx = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) * 0.5).to(device)
        gt_bboxes_cy = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) * 0.5).to(device)
        gt_bboxes_grid_x = torch.floor(gt_bboxes_cx / stride[0]).long()
        gt_bboxes_grid_y = torch.floor(gt_bboxes_cy / stride[1]).long()

        # row major indexing
        gt_bboxes_grid_idx = gt_bboxes_grid_y * feat_w + gt_bboxes_grid_x

        responsible_grid = torch.zeros(
            feat_h * feat_w, dtype=torch.uint8, device=device)
        responsible_grid[gt_bboxes_grid_idx] = 1

        responsible_grid = responsible_grid[:, None].expand(
            responsible_grid.size(0), num_base_anchors).contiguous().view(-1)
        return responsible_grid
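Editor's note: a short usage sketch for ``YOLOAnchorGenerator`` above, not
part of the commit. The base sizes below are the standard YOLOv3 anchors,
assumed here purely for illustration:

>>> import torch
>>> gen = YOLOAnchorGenerator(
...     strides=[32, 16, 8],
...     base_sizes=[[(116, 90), (156, 198), (373, 326)],
...                 [(30, 61), (62, 45), (59, 119)],
...                 [(10, 13), (16, 30), (33, 23)]])
>>> gt = torch.tensor([[100., 100., 200., 200.]])
>>> flags = gen.responsible_flags([(13, 13), (26, 26), (52, 52)], gt,
...                               device='cpu')
>>> [f.shape[0] for f in flags]  # feat_h * feat_w * num_base_anchors
[507, 2028, 8112]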
        mmdet/core/anchor/builder.py
    ADDED
    
@@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmcv.utils import Registry, build_from_cfg

PRIOR_GENERATORS = Registry('Generator for anchors and points')

ANCHOR_GENERATORS = PRIOR_GENERATORS


def build_prior_generator(cfg, default_args=None):
    return build_from_cfg(cfg, PRIOR_GENERATORS, default_args)


def build_anchor_generator(cfg, default_args=None):
    warnings.warn(
        '``build_anchor_generator`` will be deprecated soon, please use '
        '``build_prior_generator`` instead')
    return build_prior_generator(cfg, default_args=default_args)
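Editor's note: a sketch of how this registry is consumed, not part of the
commit. ``AnchorGenerator`` is registered elsewhere in this package, and the
import path assumes the package ``__init__`` re-exports the builder as in
upstream mmdet; the config values are illustrative assumptions:

>>> from mmdet.core.anchor import build_prior_generator
>>> cfg = dict(
...     type='AnchorGenerator',
...     strides=[4, 8, 16, 32, 64],
...     ratios=[0.5, 1.0, 2.0],
...     scales=[8])
>>> gen = build_prior_generator(cfg)
>>> gen.num_levels
5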
        mmdet/core/anchor/point_generator.py
    ADDED
    
@@ -0,0 +1,263 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from torch.nn.modules.utils import _pair

from .builder import PRIOR_GENERATORS


@PRIOR_GENERATORS.register_module()
class PointGenerator:

    def _meshgrid(self, x, y, row_major=True):
        xx = x.repeat(len(y))
        yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
        if row_major:
            return xx, yy
        else:
            return yy, xx

    def grid_points(self, featmap_size, stride=16, device='cuda'):
        feat_h, feat_w = featmap_size
        shift_x = torch.arange(0., feat_w, device=device) * stride
        shift_y = torch.arange(0., feat_h, device=device) * stride
        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        stride = shift_x.new_full((shift_xx.shape[0], ), stride)
        shifts = torch.stack([shift_xx, shift_yy, stride], dim=-1)
        all_points = shifts.to(device)
        return all_points

    def valid_flags(self, featmap_size, valid_size, device='cuda'):
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w
        valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
        valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
        valid_x[:valid_w] = 1
        valid_y[:valid_h] = 1
        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
        valid = valid_xx & valid_yy
        return valid

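
# Editor's note: a small illustration of ``PointGenerator`` above, not part
# of the original file. On a 2x2 feature map with stride 16 it yields one
# (x, y, stride) row per grid cell:
#
# >>> PointGenerator().grid_points((2, 2), stride=16, device='cpu')
# tensor([[ 0.,  0., 16.],
#         [16.,  0., 16.],
#         [ 0., 16., 16.],
#         [16., 16., 16.]])
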
| 43 | 
         
            +
            @PRIOR_GENERATORS.register_module()
         
     | 
| 44 | 
         
            +
            class MlvlPointGenerator:
         
     | 
| 45 | 
         
            +
                """Standard points generator for multi-level (Mlvl) feature maps in 2D
         
     | 
| 46 | 
         
            +
                points-based detectors.
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
                Args:
         
     | 
| 49 | 
         
            +
                    strides (list[int] | list[tuple[int, int]]): Strides of anchors
         
     | 
| 50 | 
         
            +
                        in multiple feature levels in order (w, h).
         
     | 
| 51 | 
         
            +
                    offset (float): The offset of points, the value is normalized with
         
     | 
| 52 | 
         
            +
                        corresponding stride. Defaults to 0.5.
         
     | 
| 53 | 
         
            +
                """
         
     | 
| 54 | 
         
            +
             
     | 
| 55 | 
         
            +
                def __init__(self, strides, offset=0.5):
         
     | 
| 56 | 
         
            +
                    self.strides = [_pair(stride) for stride in strides]
         
     | 
| 57 | 
         
            +
                    self.offset = offset
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
                @property
         
     | 
| 60 | 
         
            +
                def num_levels(self):
         
     | 
| 61 | 
         
            +
                    """int: number of feature levels that the generator will be applied"""
         
     | 
| 62 | 
         
            +
                    return len(self.strides)
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
                @property
         
     | 
| 65 | 
         
            +
                def num_base_priors(self):
         
     | 
| 66 | 
         
            +
                    """list[int]: The number of priors (points) at a point
         
     | 
| 67 | 
         
            +
                    on the feature grid"""
         
     | 
| 68 | 
         
            +
                    return [1 for _ in range(len(self.strides))]
         
     | 
| 69 | 
         
            +
             
     | 
| 70 | 
         
            +
                def _meshgrid(self, x, y, row_major=True):
         
     | 
| 71 | 
         
            +
                    yy, xx = torch.meshgrid(y, x)
         
     | 
| 72 | 
         
            +
                    if row_major:
         
     | 
| 73 | 
         
            +
                        # warning .flatten() would cause error in ONNX exporting
         
     | 
| 74 | 
         
            +
                        # have to use reshape here
         
     | 
| 75 | 
         
            +
                        return xx.reshape(-1), yy.reshape(-1)
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
                    else:
         
     | 
| 78 | 
         
            +
                        return yy.reshape(-1), xx.reshape(-1)
         
     | 
| 79 | 
         
            +
             
     | 
| 80 | 
         
            +
                def grid_priors(self,
         
     | 
| 81 | 
         
            +
                                featmap_sizes,
         
     | 
| 82 | 
         
            +
                                dtype=torch.float32,
         
     | 
| 83 | 
         
            +
                                device='cuda',
         
     | 
| 84 | 
         
            +
                                with_stride=False):
         
     | 
| 85 | 
         
            +
                    """Generate grid points of multiple feature levels.
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
                    Args:
         
     | 
| 88 | 
         
            +
                        featmap_sizes (list[tuple]): List of feature map sizes in
         
     | 
| 89 | 
         
            +
                            multiple feature levels, each size arrange as
         
     | 
| 90 | 
         
            +
                            as (h, w).
         
     | 
| 91 | 
         
            +
                        dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
         
     | 
| 92 | 
         
            +
                        device (str): The device where the anchors will be put on.
         
     | 
| 93 | 
         
            +
                        with_stride (bool): Whether to concatenate the stride to
         
     | 
| 94 | 
         
            +
                            the last dimension of points.
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
                    Return:
         
     | 
| 97 | 
         
            +
                        list[torch.Tensor]: Points of  multiple feature levels.
         
     | 
| 98 | 
         
            +
                        The sizes of each tensor should be (N, 2) when with stride is
         
     | 
| 99 | 
         
            +
                        ``False``, where N = width * height, width and height
         
     | 
| 100 | 
         
            +
                        are the sizes of the corresponding feature level,
         
     | 
| 101 | 
         
            +
                        and the last dimension 2 represent (coord_x, coord_y),
         
     | 
| 102 | 
         
            +
                        otherwise the shape should be (N, 4),
         
     | 
| 103 | 
         
            +
                        and the last dimension 4 represent
         
     | 
| 104 | 
         
            +
                        (coord_x, coord_y, stride_w, stride_h).
         
     | 
| 105 | 
         
            +
                    """
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
                    assert self.num_levels == len(featmap_sizes)
         
     | 
| 108 | 
         
            +
                    multi_level_priors = []
         
     | 
| 109 | 
         
            +
                    for i in range(self.num_levels):
         
     | 
| 110 | 
         
            +
                        priors = self.single_level_grid_priors(
         
     | 
| 111 | 
         
            +
                            featmap_sizes[i],
         
     | 
| 112 | 
         
            +
                            level_idx=i,
         
     | 
| 113 | 
         
            +
                            dtype=dtype,
         
     | 
| 114 | 
         
            +
                            device=device,
         
     | 
| 115 | 
         
            +
                            with_stride=with_stride)
         
     | 
| 116 | 
         
            +
                        multi_level_priors.append(priors)
         
     | 
| 117 | 
         
            +
                    return multi_level_priors
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
                def single_level_grid_priors(self,
         
     | 
| 120 | 
         
            +
                                             featmap_size,
         
     | 
| 121 | 
         
            +
                                             level_idx,
         
     | 
| 122 | 
         
            +
                                             dtype=torch.float32,
         
     | 
| 123 | 
         
            +
                                             device='cuda',
         
     | 
| 124 | 
         
            +
                                             with_stride=False):
         
     | 
| 125 | 
         
            +
                    """Generate grid Points of a single level.
         
     | 
| 126 | 
         
            +
             
     | 
| 127 | 
         
            +
                    Note:
         
     | 
| 128 | 
         
            +
                        This function is usually called by method ``self.grid_priors``.
         
     | 
| 129 | 
         
            +
             
     | 
| 130 | 
         
            +
                    Args:
         
     | 
| 131 | 
         
            +
                        featmap_size (tuple[int]): Size of the feature maps, arrange as
         
     | 
| 132 | 
         
            +
                            (h, w).
         
     | 
| 133 | 
         
            +
                        level_idx (int): The index of corresponding feature map level.
         
     | 
| 134 | 
         
            +
                        dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
         
     | 
| 135 | 
         
            +
                        device (str, optional): The device the tensor will be put on.
         
     | 
| 136 | 
         
            +
                            Defaults to 'cuda'.
         
     | 
| 137 | 
         
            +
                        with_stride (bool): Concatenate the stride to the last dimension
         
     | 
| 138 | 
         
            +
                            of points.
         
     | 
| 139 | 
         
            +
             
     | 
| 140 | 
         
            +
                    Return:
         
     | 
| 141 | 
         
            +
                        Tensor: Points of single feature levels.
         
     | 
| 142 | 
         
            +
                        The shape of tensor should be (N, 2) when with stride is
         
     | 
| 143 | 
         
            +
                        ``False``, where N = width * height, width and height
         
     | 
| 144 | 
         
            +
                        are the sizes of the corresponding feature level,
         
     | 
| 145 | 
         
            +
                        and the last dimension 2 represent (coord_x, coord_y),
         
     | 
| 146 | 
         
            +
                        otherwise the shape should be (N, 4),
         
     | 
| 147 | 
         
            +
                        and the last dimension 4 represent
         
     | 
| 148 | 
         
            +
                        (coord_x, coord_y, stride_w, stride_h).
         
     | 
| 149 | 
         
            +
                    """
         
     | 
| 150 | 
         
            +
                    feat_h, feat_w = featmap_size
         
     | 
| 151 | 
         
            +
                    stride_w, stride_h = self.strides[level_idx]
         
     | 
| 152 | 
         
            +
                    shift_x = (torch.arange(0, feat_w, device=device) +
         
     | 
| 153 | 
         
            +
                               self.offset) * stride_w
         
     | 
| 154 | 
         
            +
                    # keep featmap_size as Tensor instead of int, so that we
         
     | 
| 155 | 
         
            +
                    # can convert to ONNX correctly
         
     | 
| 156 | 
         
            +
                    shift_x = shift_x.to(dtype)
         
     | 
| 157 | 
         
            +
             
     | 
| 158 | 
         
            +
                    shift_y = (torch.arange(0, feat_h, device=device) +
         
     | 
| 159 | 
         
            +
                               self.offset) * stride_h
         
     | 
| 160 | 
         
            +
                    # keep featmap_size as Tensor instead of int, so that we
         
     | 
| 161 | 
         
            +
                    # can convert to ONNX correctly
         
     | 
| 162 | 
         
            +
                    shift_y = shift_y.to(dtype)
         
     | 
| 163 | 
         
            +
                    shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
         
     | 
| 164 | 
         
            +
                    if not with_stride:
         
     | 
| 165 | 
         
            +
                        shifts = torch.stack([shift_xx, shift_yy], dim=-1)
         
     | 
| 166 | 
         
            +
                    else:
         
     | 
| 167 | 
         
            +
                        # use `shape[0]` instead of `len(shift_xx)` for ONNX export
         
     | 
| 168 | 
         
            +
                        stride_w = shift_xx.new_full((shift_xx.shape[0], ),
         
     | 
| 169 | 
         
            +
                                                     stride_w).to(dtype)
         
     | 
| 170 | 
         
            +
                        stride_h = shift_xx.new_full((shift_yy.shape[0], ),
         
     | 
| 171 | 
         
            +
                                                     stride_h).to(dtype)
         
     | 
| 172 | 
         
            +
                        shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h],
         
     | 
| 173 | 
         
            +
                                             dim=-1)
         
     | 
| 174 | 
         
            +
                    all_points = shifts.to(device)
         
     | 
| 175 | 
         
            +
                    return all_points
         
     | 
| 176 | 
         
            +
             
     | 
| 177 | 
         
            +
                def valid_flags(self, featmap_sizes, pad_shape, device='cuda'):
         
     | 
| 178 | 
         
            +
                    """Generate valid flags of points of multiple feature levels.
         
     | 
| 179 | 
         
            +
             
     | 
| 180 | 
         
            +
                    Args:
         
     | 
| 181 | 
         
            +
                        featmap_sizes (list(tuple)): List of feature map sizes in
         
     | 
| 182 | 
         
            +
                            multiple feature levels, each size arrange as
         
     | 
| 183 | 
         
            +
                            as (h, w).
         
     | 
| 184 | 
         
            +
                        pad_shape (tuple(int)): The padded shape of the image,
         
     | 
| 185 | 
         
            +
                             arrange as (h, w).
         
     | 
| 186 | 
         
            +
                        device (str): The device where the anchors will be put on.
         
     | 
| 187 | 
         
            +
             
     | 
| 188 | 
         
            +
                    Return:
         
     | 
| 189 | 
         
            +
                        list(torch.Tensor): Valid flags of points of multiple levels.
         
     | 
| 190 | 
         
            +
                    """
         
     | 
| 191 | 
         
            +
                    assert self.num_levels == len(featmap_sizes)
         
     | 
| 192 | 
         
            +
                    multi_level_flags = []
         
     | 
| 193 | 
         
            +
                    for i in range(self.num_levels):
         
     | 
| 194 | 
         
            +
                        point_stride = self.strides[i]
         
     | 
| 195 | 
         
            +
                        feat_h, feat_w = featmap_sizes[i]
         
     | 
| 196 | 
         
            +
                        h, w = pad_shape[:2]
         
     | 
| 197 | 
         
            +
                        valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h)
         
     | 
| 198 | 
         
            +
                        valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w)
         
     | 
| 199 | 
         
            +
                        flags = self.single_level_valid_flags((feat_h, feat_w),
         
     | 
| 200 | 
         
            +
                                                              (valid_feat_h, valid_feat_w),
         
     | 
| 201 | 
         
            +
                                                              device=device)
         
     | 
| 202 | 
         
            +
                        multi_level_flags.append(flags)
         
     | 
| 203 | 
         
            +
                    return multi_level_flags
         
     | 
| 204 | 
         
            +
             
     | 
| 205 | 
         
            +
                def single_level_valid_flags(self,
         
     | 
| 206 | 
         
            +
                                             featmap_size,
         
     | 
| 207 | 
         
            +
                                             valid_size,
         
     | 
| 208 | 
         
            +
                                             device='cuda'):
         
     | 
| 209 | 
         
            +
                    """Generate the valid flags of points of a single feature map.
         
     | 
| 210 | 
         
            +
             
     | 
| 211 | 
         
            +
                    Args:
         
     | 
| 212 | 
         
            +
                        featmap_size (tuple[int]): The size of feature maps, arrange as
         
     | 
| 213 | 
         
            +
                            as (h, w).
         
     | 
| 214 | 
         
            +
                        valid_size (tuple[int]): The valid size of the feature maps.
         
     | 
| 215 | 
         
+                The size arranged as (h, w).
+            device (str, optional): The device where the flags are put.
+                Defaults to 'cuda'.
+
+        Returns:
+            torch.Tensor: The valid flags of each point in a single level \
+                feature map.
+        """
+        feat_h, feat_w = featmap_size
+        valid_h, valid_w = valid_size
+        assert valid_h <= feat_h and valid_w <= feat_w
+        valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
+        valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
+        valid_x[:valid_w] = 1
+        valid_y[:valid_h] = 1
+        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
+        valid = valid_xx & valid_yy
+        return valid
+
+    def sparse_priors(self,
+                      prior_idxs,
+                      featmap_size,
+                      level_idx,
+                      dtype=torch.float32,
+                      device='cuda'):
+        """Generate sparse points according to the ``prior_idxs``.
+
+        Args:
+            prior_idxs (Tensor): The index of corresponding anchors
+                in the feature map.
+            featmap_size (tuple[int]): Feature map size arranged as (h, w).
+            level_idx (int): The level index of the corresponding feature
+                map.
+            dtype (obj:`torch.dtype`): Data type of points. Defaults to
+                ``torch.float32``.
+            device (obj:`torch.device`): The device where the points are
+                located.
+        Returns:
+            Tensor: Anchors with shape (N, 2), where N equals the length of
+            ``prior_idxs`` and the last dimension 2 represents
+            (coord_x, coord_y).
+        """
+        height, width = featmap_size
+        x = (prior_idxs % width + self.offset) * self.strides[level_idx][0]
+        y = ((prior_idxs // width) % height +
+             self.offset) * self.strides[level_idx][1]
+        priors = torch.stack([x, y], 1).to(dtype)
+        priors = priors.to(device)
+        return priors
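The index-to-coordinate arithmetic in `sparse_priors` above is a row-major unflatten followed by per-level stride scaling. A minimal sketch of the same computation on made-up values (the 4x6 feature map, stride 8, and offset 0.5 below are illustrative assumptions, not values from any config):

import torch

# Assumed toy setup: a 4x6 (h, w) feature map, stride 8, offset 0.5.
prior_idxs = torch.tensor([0, 5, 6, 23])  # flat row-major indices
height, width, stride, offset = 4, 6, 8, 0.5

x = (prior_idxs % width + offset) * stride              # column -> image x
y = ((prior_idxs // width) % height + offset) * stride  # row -> image y
print(torch.stack([x, y], 1))
# index 0 -> (4., 4.),  index 5 -> (44., 4.),
# index 6 -> (4., 12.), index 23 -> (44., 28.)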
    	
mmdet/core/anchor/utils.py
ADDED
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+def images_to_levels(target, num_levels):
+    """Convert targets by image to targets by feature level.
+
+    [target_img0, target_img1] -> [target_level0, target_level1, ...]
+    """
+    target = torch.stack(target, 0)
+    level_targets = []
+    start = 0
+    for n in num_levels:
+        end = start + n
+        # level_targets.append(target[:, start:end].squeeze(0))
+        level_targets.append(target[:, start:end])
+        start = end
+    return level_targets
+
+
+def anchor_inside_flags(flat_anchors,
+                        valid_flags,
+                        img_shape,
+                        allowed_border=0):
+    """Check whether the anchors are inside the border.
+
+    Args:
+        flat_anchors (torch.Tensor): Flattened anchors, shape (n, 4).
+        valid_flags (torch.Tensor): Existing valid flags of the anchors.
+        img_shape (tuple(int)): Shape of current image.
+        allowed_border (int, optional): The border to allow the valid anchor.
+            Defaults to 0.
+
+    Returns:
+        torch.Tensor: Flags indicating whether the anchors are inside a \
+            valid range.
+    """
+    img_h, img_w = img_shape[:2]
+    if allowed_border >= 0:
+        inside_flags = valid_flags & \
+            (flat_anchors[:, 0] >= -allowed_border) & \
+            (flat_anchors[:, 1] >= -allowed_border) & \
+            (flat_anchors[:, 2] < img_w + allowed_border) & \
+            (flat_anchors[:, 3] < img_h + allowed_border)
+    else:
+        inside_flags = valid_flags
+    return inside_flags
+
+
+def calc_region(bbox, ratio, featmap_size=None):
+    """Calculate a proportional bbox region.
+
+    The bbox center is fixed and the new h' and w' are h * ratio and
+    w * ratio.
+
+    Args:
+        bbox (Tensor): Bboxes to calculate regions, shape (n, 4).
+        ratio (float): Ratio of the output region.
+        featmap_size (tuple): Feature map size used for clipping the boundary.
+
+    Returns:
+        tuple: x1, y1, x2, y2
+    """
+    x1 = torch.round((1 - ratio) * bbox[0] + ratio * bbox[2]).long()
+    y1 = torch.round((1 - ratio) * bbox[1] + ratio * bbox[3]).long()
+    x2 = torch.round(ratio * bbox[0] + (1 - ratio) * bbox[2]).long()
+    y2 = torch.round(ratio * bbox[1] + (1 - ratio) * bbox[3]).long()
+    if featmap_size is not None:
+        x1 = x1.clamp(min=0, max=featmap_size[1])
+        y1 = y1.clamp(min=0, max=featmap_size[0])
+        x2 = x2.clamp(min=0, max=featmap_size[1])
+        y2 = y2.clamp(min=0, max=featmap_size[0])
+    return (x1, y1, x2, y2)
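A quick sanity check of the helpers above on hand-written inputs (the image shape and boxes are made up for illustration):

import torch
from mmdet.core.anchor.utils import anchor_inside_flags, calc_region

# One anchor fully inside a 100x100 image, one crossing the right/bottom border.
flat_anchors = torch.tensor([[10., 10., 50., 50.],
                             [90., 90., 120., 120.]])
valid_flags = torch.ones(2, dtype=torch.bool)
print(anchor_inside_flags(flat_anchors, valid_flags, (100, 100)))
# tensor([ True, False])

# calc_region with ratio=0.25 keeps the central half of each side:
print(calc_region(torch.tensor([0., 0., 40., 40.]), 0.25))
# (tensor(10), tensor(10), tensor(30), tensor(30))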
    	
mmdet/core/bbox/__init__.py
ADDED
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .assigners import (AssignResult, BaseAssigner, CenterRegionAssigner,
+                        MaxIoUAssigner, RegionAssigner)
+from .builder import build_assigner, build_bbox_coder, build_sampler
+from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, DistancePointBBoxCoder,
+                    PseudoBBoxCoder, TBLRBBoxCoder)
+from .iou_calculators import BboxOverlaps2D, bbox_overlaps
+from .samplers import (BaseSampler, CombinedSampler,
+                       InstanceBalancedPosSampler, IoUBalancedNegSampler,
+                       OHEMSampler, PseudoSampler, RandomSampler,
+                       SamplingResult, ScoreHLRSampler)
+from .transforms import (bbox2distance, bbox2result, bbox2roi,
+                         bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping,
+                         bbox_mapping_back, bbox_rescale, bbox_xyxy_to_cxcywh,
+                         distance2bbox, find_inside_bboxes, roi2bbox)
+
+__all__ = [
+    'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner',
+    'AssignResult', 'BaseSampler', 'PseudoSampler', 'RandomSampler',
+    'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
+    'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'build_assigner',
+    'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back',
+    'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance',
+    'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
+    'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'DistancePointBBoxCoder',
+    'CenterRegionAssigner', 'bbox_rescale', 'bbox_cxcywh_to_xyxy',
+    'bbox_xyxy_to_cxcywh', 'RegionAssigner', 'find_inside_bboxes'
+]
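The box-format transforms re-exported here are plain tensor ops; a round-trip sketch with a hand-made box (the values are illustrative only):

import torch
from mmdet.core.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh

boxes = torch.tensor([[10., 20., 50., 80.]])   # (x1, y1, x2, y2)
cxcywh = bbox_xyxy_to_cxcywh(boxes)            # -> [[30., 50., 40., 60.]]
print(bbox_cxcywh_to_xyxy(cxcywh))             # back to [[10., 20., 50., 80.]]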
    	
mmdet/core/bbox/assigners/__init__.py
ADDED
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .approx_max_iou_assigner import ApproxMaxIoUAssigner
+from .ascend_assign_result import AscendAssignResult
+from .ascend_max_iou_assigner import AscendMaxIoUAssigner
+from .assign_result import AssignResult
+from .atss_assigner import ATSSAssigner
+from .base_assigner import BaseAssigner
+from .center_region_assigner import CenterRegionAssigner
+from .grid_assigner import GridAssigner
+from .hungarian_assigner import HungarianAssigner
+from .mask_hungarian_assigner import MaskHungarianAssigner
+from .max_iou_assigner import MaxIoUAssigner
+from .point_assigner import PointAssigner
+from .region_assigner import RegionAssigner
+from .sim_ota_assigner import SimOTAAssigner
+from .task_aligned_assigner import TaskAlignedAssigner
+from .uniform_assigner import UniformAssigner
+
+__all__ = [
+    'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
+    'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
+    'HungarianAssigner', 'RegionAssigner', 'UniformAssigner', 'SimOTAAssigner',
+    'TaskAlignedAssigner', 'MaskHungarianAssigner', 'AscendAssignResult',
+    'AscendMaxIoUAssigner'
+]
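Every assigner exported here is registered in the BBOX_ASSIGNERS registry, so detector configs build them from plain dicts. A minimal sketch (the thresholds are typical RPN-style values, used here only as an example):

from mmdet.core.bbox import build_assigner

assigner = build_assigner(
    dict(
        type='MaxIoUAssigner',
        pos_iou_thr=0.7,
        neg_iou_thr=0.3,
        min_pos_iou=0.3,
        match_low_quality=True,
        ignore_iof_thr=-1))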
    	
mmdet/core/bbox/assigners/approx_max_iou_assigner.py
ADDED
@@ -0,0 +1,146 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import build_iou_calculator
+from .max_iou_assigner import MaxIoUAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class ApproxMaxIoUAssigner(MaxIoUAssigner):
+    """Assign a corresponding gt bbox or background to each bbox.
+
+    Each proposal will be assigned an integer indicating the ground truth
+    index (semi-positive index: gt label (0-based), -1: background).
+
+    - -1: negative sample, no assigned gt
+    - semi-positive integer: positive sample, index (0-based) of assigned gt
+
+    Args:
+        pos_iou_thr (float): IoU threshold for positive bboxes.
+        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
+        min_pos_iou (float): Minimum IoU for a bbox to be considered as a
+            positive bbox. Positive samples can have smaller IoU than
+            pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
+        gt_max_assign_all (bool): Whether to assign all bboxes with the same
+            highest overlap with some gt to that gt.
+        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
+            `gt_bboxes_ignore` is specified). Negative values mean not
+            ignoring any bboxes.
+        ignore_wrt_candidates (bool): Whether to compute the IoF between
+            `bboxes` and `gt_bboxes_ignore`, or the contrary.
+        match_low_quality (bool): Whether to allow low quality matches. This
+            is usually allowed for RPN and single stage detectors, but not
+            allowed in the second stage.
+        gpu_assign_thr (int): The upper bound of the number of GT for GPU
+            assign. When the number of gt is above this threshold, assignment
+            will be done on CPU. Negative values mean never assigning on CPU.
+    """
+
+    def __init__(self,
+                 pos_iou_thr,
+                 neg_iou_thr,
+                 min_pos_iou=.0,
+                 gt_max_assign_all=True,
+                 ignore_iof_thr=-1,
+                 ignore_wrt_candidates=True,
+                 match_low_quality=True,
+                 gpu_assign_thr=-1,
+                 iou_calculator=dict(type='BboxOverlaps2D')):
+        self.pos_iou_thr = pos_iou_thr
+        self.neg_iou_thr = neg_iou_thr
+        self.min_pos_iou = min_pos_iou
+        self.gt_max_assign_all = gt_max_assign_all
+        self.ignore_iof_thr = ignore_iof_thr
+        self.ignore_wrt_candidates = ignore_wrt_candidates
+        self.gpu_assign_thr = gpu_assign_thr
+        self.match_low_quality = match_low_quality
+        self.iou_calculator = build_iou_calculator(iou_calculator)
+
+    def assign(self,
+               approxs,
+               squares,
+               approxs_per_octave,
+               gt_bboxes,
+               gt_bboxes_ignore=None,
+               gt_labels=None):
+        """Assign gt to approxs.
+
+        This method assigns a gt bbox to each group of approxs (bboxes);
+        each group of approxs is represented by a base approx (bbox) and
+        will be assigned with -1, or a semi-positive number.
+        background_label (-1) means negative sample,
+        a semi-positive number is the index (0-based) of the assigned gt.
+        The assignment is done in the following steps; the order matters.
+
+        1. assign every bbox to background_label (-1)
+        2. use the max IoU of each group of approxs to assign
+        3. assign proposals whose iou with all gts < neg_iou_thr to background
+        4. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
+           assign it to that bbox
+        5. for each gt bbox, assign its nearest proposals (may be more than
+           one) to itself
+
+        Args:
+            approxs (Tensor): Bounding boxes to be assigned,
+                shape(approxs_per_octave*n, 4).
+            squares (Tensor): Base bounding boxes to be assigned,
+                shape(n, 4).
+            approxs_per_octave (int): number of approxs per octave
+            gt_bboxes (Tensor): Ground truth boxes, shape (k, 4).
+            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+                labelled as `ignored`, e.g., crowd boxes in COCO.
+            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+        Returns:
+            :obj:`AssignResult`: The assign result.
+        """
+        num_squares = squares.size(0)
+        num_gts = gt_bboxes.size(0)
+
+        if num_squares == 0 or num_gts == 0:
+            # No predictions and/or truth, return empty assignment
+            overlaps = approxs.new(num_gts, num_squares)
+            assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
+            return assign_result
+
+        # re-organize anchors by approxs_per_octave x num_squares
+        approxs = torch.transpose(
+            approxs.view(num_squares, approxs_per_octave, 4), 0,
+            1).contiguous().view(-1, 4)
+        assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
+            num_gts > self.gpu_assign_thr) else False
+        # compute overlap and assign gt on CPU when number of GT is large
+        if assign_on_cpu:
+            device = approxs.device
+            approxs = approxs.cpu()
+            gt_bboxes = gt_bboxes.cpu()
+            if gt_bboxes_ignore is not None:
+                gt_bboxes_ignore = gt_bboxes_ignore.cpu()
+            if gt_labels is not None:
+                gt_labels = gt_labels.cpu()
+        all_overlaps = self.iou_calculator(approxs, gt_bboxes)
+
+        overlaps, _ = all_overlaps.view(approxs_per_octave, num_squares,
+                                        num_gts).max(dim=0)
+        overlaps = torch.transpose(overlaps, 0, 1)
+
+        if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
+                and gt_bboxes_ignore.numel() > 0 and squares.numel() > 0):
+            if self.ignore_wrt_candidates:
+                ignore_overlaps = self.iou_calculator(
+                    squares, gt_bboxes_ignore, mode='iof')
+                ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
+            else:
+                ignore_overlaps = self.iou_calculator(
+                    gt_bboxes_ignore, squares, mode='iof')
+                ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
+            overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
+
+        assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
+        if assign_on_cpu:
+            assign_result.gt_inds = assign_result.gt_inds.to(device)
+            assign_result.max_overlaps = assign_result.max_overlaps.to(device)
+            if assign_result.labels is not None:
+                assign_result.labels = assign_result.labels.to(device)
+        return assign_result
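A minimal sketch of calling the assigner with dummy tensors (the boxes, thresholds, and octave count are invented for illustration; in training the inputs come from the approx/square anchor generators and the dataset):

import torch
from mmdet.core.bbox.assigners import ApproxMaxIoUAssigner

assigner = ApproxMaxIoUAssigner(pos_iou_thr=0.5, neg_iou_thr=0.4)

# Two base squares, three approxs per octave (here just copies of the squares).
squares = torch.tensor([[0., 0., 32., 32.], [32., 32., 64., 64.]])
approxs = squares.repeat_interleave(3, dim=0)  # shape (2 * 3, 4)
gt_bboxes = torch.tensor([[2., 2., 30., 30.]])

result = assigner.assign(approxs, squares, 3, gt_bboxes,
                         gt_labels=torch.tensor([0]))
print(result.gt_inds)
# tensor([1, 0]): square 0 matched to gt 1 (1-based), square 1 is background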
    	
mmdet/core/bbox/assigners/ascend_assign_result.py
ADDED
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdet.utils import util_mixins
+
+
+class AscendAssignResult(util_mixins.NiceRepr):
+    """Stores ascend assignments between predicted and truth boxes.
+
+    Arguments:
+        batch_num_gts (list[int]): The number of truth boxes considered.
+        batch_pos_mask (IntTensor): Positive samples mask in all images.
+        batch_neg_mask (IntTensor): Negative samples mask in all images.
+        batch_max_overlaps (FloatTensor): The max overlaps of all bboxes
+            and ground truth boxes.
+        batch_anchor_gt_indes (None | LongTensor): The assigned truth
+            box index of all anchors.
+        batch_anchor_gt_labels (None | LongTensor): The gt labels
+            of all anchors.
+    """
+
+    def __init__(self,
+                 batch_num_gts,
+                 batch_pos_mask,
+                 batch_neg_mask,
+                 batch_max_overlaps,
+                 batch_anchor_gt_indes=None,
+                 batch_anchor_gt_labels=None):
+        self.batch_num_gts = batch_num_gts
+        self.batch_pos_mask = batch_pos_mask
+        self.batch_neg_mask = batch_neg_mask
+        self.batch_max_overlaps = batch_max_overlaps
+        self.batch_anchor_gt_indes = batch_anchor_gt_indes
+        self.batch_anchor_gt_labels = batch_anchor_gt_labels
+        # Interface for possible user-defined properties
+        self._extra_properties = {}
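AscendAssignResult is a plain batched container; a sketch of the expected shapes with made-up sizes (a batch of 2 images with 4 anchors each, values purely illustrative):

import torch

result = AscendAssignResult(
    batch_num_gts=[2, 1],
    batch_pos_mask=torch.zeros(2, 4, dtype=torch.int),
    batch_neg_mask=torch.ones(2, 4, dtype=torch.int),
    batch_max_overlaps=torch.zeros(2, 4),
    batch_anchor_gt_indes=torch.zeros(2, 4, dtype=torch.long))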
    	
mmdet/core/bbox/assigners/ascend_max_iou_assigner.py
ADDED
@@ -0,0 +1,178 @@
| 1 | 
         
            +
            # Copyright (c) OpenMMLab. All rights reserved.
         
     | 
| 2 | 
         
            +
            import torch
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            from ....utils import masked_fill
         
     | 
| 5 | 
         
            +
            from ..builder import BBOX_ASSIGNERS
         
     | 
| 6 | 
         
            +
            from ..iou_calculators import build_iou_calculator
         
     | 
| 7 | 
         
            +
            from .ascend_assign_result import AscendAssignResult
         
     | 
| 8 | 
         
            +
            from .base_assigner import BaseAssigner
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            @BBOX_ASSIGNERS.register_module()
         
     | 
| 12 | 
         
            +
            class AscendMaxIoUAssigner(BaseAssigner):
         
     | 
| 13 | 
         
            +
                """Assign a corresponding gt bbox or background to each bbox.
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
                Each proposals will be assigned with `-1`, or a semi-positive integer
         
     | 
| 16 | 
         
            +
                indicating the ground truth index.
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
                - -1: negative sample, no assigned gt
         
     | 
| 19 | 
         
            +
                - semi-positive integer: positive sample, index (0-based) of assigned gt
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
                Args:
         
     | 
| 22 | 
         
            +
                    pos_iou_thr (float): IoU threshold for positive bboxes.
         
     | 
| 23 | 
         
            +
                    neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
         
     | 
| 24 | 
         
            +
                    min_pos_iou (float): Minimum iou for a bbox to be considered as a
         
     | 
| 25 | 
         
            +
                        positive bbox. Positive samples can have smaller IoU than
         
     | 
| 26 | 
         
            +
                        pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
         
     | 
| 27 | 
         
            +
                        `min_pos_iou` is set to avoid assigning bboxes that have extremely
         
     | 
| 28 | 
         
            +
                        small iou with GT as positive samples. It brings about 0.3 mAP
         
     | 
| 29 | 
         
            +
                        improvements in 1x schedule but does not affect the performance of
         
     | 
| 30 | 
         
            +
                        3x schedule. More comparisons can be found in
         
     | 
| 31 | 
         
            +
                        `PR #7464 <https://github.com/open-mmlab/mmdetection/pull/7464>`_.
         
     | 
| 32 | 
         
            +
                    gt_max_assign_all (bool): Whether to assign all bboxes with the same
         
     | 
| 33 | 
         
            +
                        highest overlap with some gt to that gt.
         
     | 
| 34 | 
         
            +
                    ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
         
     | 
| 35 | 
         
            +
                        `gt_bboxes_ignore` is specified). Negative values mean not
         
     | 
| 36 | 
         
            +
                        ignoring any bboxes.
         
     | 
| 37 | 
         
            +
                    ignore_wrt_candidates (bool): Whether to compute the iof between
         
     | 
| 38 | 
         
            +
                        `bboxes` and `gt_bboxes_ignore`, or the contrary.
         
     | 
| 39 | 
         
            +
                    match_low_quality (bool): Whether to allow low quality matches. This is
         
     | 
| 40 | 
         
            +
                        usually allowed for RPN and single stage detectors, but not allowed
         
     | 
| 41 | 
         
            +
                        in the second stage. Details are demonstrated in Step 4.
         
     | 
| 42 | 
         
            +
                    gpu_assign_thr (int): The upper bound of the number of GT for GPU
         
     | 
| 43 | 
         
            +
                        assign. When the number of gt is above this threshold, will assign
         
     | 
| 44 | 
         
            +
                        on CPU device. Negative values mean not assign on CPU.
         
     | 
| 45 | 
         
            +
                """
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
                def __init__(self,
         
     | 
| 48 | 
         
            +
                             pos_iou_thr,
         
     | 
| 49 | 
         
            +
                             neg_iou_thr,
         
     | 
| 50 | 
         
            +
                             min_pos_iou=.0,
         
     | 
| 51 | 
         
            +
                             gt_max_assign_all=True,
         
     | 
| 52 | 
         
            +
                             ignore_iof_thr=-1,
         
     | 
| 53 | 
         
            +
                             ignore_wrt_candidates=True,
         
     | 
| 54 | 
         
            +
                             match_low_quality=True,
         
     | 
| 55 | 
         
            +
                             gpu_assign_thr=-1,
         
     | 
| 56 | 
         
            +
                             iou_calculator=dict(type='BboxOverlaps2D')):
         
     | 
| 57 | 
         
            +
                    self.pos_iou_thr = pos_iou_thr
         
     | 
| 58 | 
         
            +
                    self.neg_iou_thr = neg_iou_thr
         
     | 
| 59 | 
         
            +
                    self.min_pos_iou = min_pos_iou
         
     | 
| 60 | 
         
            +
                    self.gt_max_assign_all = gt_max_assign_all
         
     | 
| 61 | 
         
            +
                    self.ignore_iof_thr = ignore_iof_thr
         
     | 
| 62 | 
         
            +
                    self.ignore_wrt_candidates = ignore_wrt_candidates
         
     | 
| 63 | 
         
            +
                    self.gpu_assign_thr = gpu_assign_thr
         
     | 
| 64 | 
         
            +
                    self.match_low_quality = match_low_quality
         
     | 
| 65 | 
         
            +
                    self.iou_calculator = build_iou_calculator(iou_calculator)
         
     | 
| 66 | 
         
            +
             
     | 
| 67 | 
         
            +
                def assign(self,
         
     | 
| 68 | 
         
            +
                           batch_bboxes,
         
     | 
| 69 | 
         
            +
                           batch_gt_bboxes,
         
     | 
| 70 | 
         
            +
                           batch_gt_bboxes_ignore=None,
         
     | 
| 71 | 
         
            +
                           batch_gt_labels=None,
         
     | 
| 72 | 
         
            +
                           batch_bboxes_ignore_mask=None,
         
     | 
| 73 | 
         
            +
                           batch_num_gts=None):
         
     | 
| 74 | 
         
            +
                    """Assign gt to bboxes.
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
                    Args:
         
     | 
| 77 | 
         
            +
                        batch_bboxes (Tensor): Bounding boxes to be assigned,
         
     | 
| 78 | 
         
            +
                            shape(b, n, 4).
         
     | 
| 79 | 
         
            +
                        batch_gt_bboxes (Tensor): Ground truth boxes,
         
     | 
| 80 | 
         
            +
                            shape (b, k, 4).
         
     | 
| 81 | 
         
            +
                        batch_gt_bboxes_ignore (Tensor, optional): Ground truth
         
     | 
| 82 | 
         
            +
                            bboxes that are labelled as `ignored`,
         
     | 
| 83 | 
         
            +
                            e.g., crowd boxes in COCO.
         
     | 
| 84 | 
         
            +
                        batch_gt_labels (Tensor, optional): Label of gt_bboxes,
         
     | 
| 85 | 
         
            +
                            shape (b, k, ).
         
     | 
| 86 | 
         
            +
                        batch_bboxes_ignore_mask: (b, n)
         
     | 
| 87 | 
         
            +
                        batch_num_gts:(b, )
         
     | 
| 88 | 
         
            +
                    Returns:
         
     | 
| 89 | 
         
            +
                        :obj:`AssignResult`: The assign result.
         
     | 
| 90 | 
         
            +
                    """
         
     | 
| 91 | 
         
            +
                    batch_overlaps = self.iou_calculator(batch_gt_bboxes, batch_bboxes)
         
     | 
| 92 | 
         
            +
                    batch_overlaps = masked_fill(
         
     | 
| 93 | 
         
            +
                        batch_overlaps,
         
     | 
| 94 | 
         
            +
                        batch_bboxes_ignore_mask.unsqueeze(1).float(),
         
     | 
| 95 | 
         
            +
                        -1,
         
     | 
| 96 | 
         
            +
                        neg=True)
         
     | 
| 97 | 
         
            +
                    if self.ignore_iof_thr > 0 and batch_gt_bboxes_ignore is not None:
         
     | 
| 98 | 
         
            +
                        if self.ignore_wrt_candidates:
         
     | 
| 99 | 
         
            +
                            batch_ignore_overlaps = self.iou_calculator(
         
     | 
| 100 | 
         
            +
                                batch_bboxes, batch_gt_bboxes_ignore, mode='iof')
         
     | 
| 101 | 
         
            +
                            batch_ignore_overlaps = masked_fill(batch_ignore_overlaps,
         
     | 
| 102 | 
         
            +
                                                                batch_bboxes_ignore_mask,
         
     | 
| 103 | 
         
            +
                                                                -1)
         
     | 
| 104 | 
         
            +
                            batch_ignore_max_overlaps, _ = batch_ignore_overlaps.max(dim=2)
         
     | 
| 105 | 
         
            +
                        else:
         
     | 
| 106 | 
         
            +
                            batch_ignore_overlaps = self.iou_calculator(
         
     | 
| 107 | 
         
            +
                                batch_gt_bboxes_ignore, batch_bboxes, mode='iof')
         
     | 
| 108 | 
         
            +
                            batch_ignore_overlaps = masked_fill(batch_ignore_overlaps,
         
     | 
| 109 | 
         
            +
                                                                batch_bboxes_ignore_mask,
         
     | 
| 110 | 
         
            +
                                                                -1)
         
     | 
| 111 | 
         
            +
                            batch_ignore_max_overlaps, _ = \
         
     | 
| 112 | 
         
            +
                                batch_ignore_overlaps.max(dim=1)
         
     | 
| 113 | 
         
            +
                        batch_ignore_mask = \
         
     | 
| 114 | 
         
            +
                            batch_ignore_max_overlaps > self.ignore_iof_thr
         
     | 
| 115 | 
         
            +
                        batch_overlaps = masked_fill(batch_overlaps, batch_ignore_mask, -1)
         
     | 
| 116 | 
         
            +
                    batch_assign_result = self.batch_assign_wrt_overlaps(
         
     | 
| 117 | 
         
            +
                        batch_overlaps, batch_gt_labels, batch_num_gts)
         
     | 
| 118 | 
         
            +
                    return batch_assign_result
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
                def batch_assign_wrt_overlaps(self,
         
     | 
| 121 | 
         
            +
                                              batch_overlaps,
         
     | 
| 122 | 
         
            +
                                              batch_gt_labels=None,
         
     | 
| 123 | 
         
            +
                                              batch_num_gts=None):
         
     | 
| 124 | 
         
            +
                    num_images, num_gts, num_bboxes = batch_overlaps.size()
         
     | 
| 125 | 
         
            +
                    batch_max_overlaps, batch_argmax_overlaps = batch_overlaps.max(dim=1)
         
     | 
| 126 | 
         
            +
        if isinstance(self.neg_iou_thr, float):
            batch_neg_mask = \
                ((batch_max_overlaps >= 0)
                 & (batch_max_overlaps < self.neg_iou_thr)).int()
        elif isinstance(self.neg_iou_thr, tuple):
            assert len(self.neg_iou_thr) == 2
            batch_neg_mask = \
                ((batch_max_overlaps >= self.neg_iou_thr[0])
                 & (batch_max_overlaps < self.neg_iou_thr[1])).int()
        else:
            batch_neg_mask = torch.zeros(
                batch_max_overlaps.size(),
                dtype=torch.int,
                device=batch_max_overlaps.device)
        batch_pos_mask = (batch_max_overlaps >= self.pos_iou_thr).int()
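        # Low-quality matching: every gt keeps its best-overlapping bbox as a
        # positive even when that overlap is below pos_iou_thr, as long as it
        # is at least min_pos_iou.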
        if self.match_low_quality:
            batch_gt_max_overlaps, batch_gt_argmax_overlaps = \
                batch_overlaps.max(dim=2)
            batch_index_bool = (batch_gt_max_overlaps >= self.min_pos_iou) & \
                               (batch_gt_max_overlaps > 0)
            if self.gt_max_assign_all:
                pos_inds_low_quality = \
                    (batch_overlaps == batch_gt_max_overlaps.unsqueeze(2)) & \
                    batch_index_bool.unsqueeze(2)
                for i in range(num_gts):
                    pos_inds_low_quality_gt = pos_inds_low_quality[:, i, :]
                    batch_argmax_overlaps[pos_inds_low_quality_gt] = i
                    batch_pos_mask[pos_inds_low_quality_gt] = 1
            else:
                index_temp = torch.arange(
                    0, num_gts, device=batch_max_overlaps.device)
                for index_image in range(num_images):
                    gt_argmax_overlaps = batch_gt_argmax_overlaps[index_image]
                    index_bool = batch_index_bool[index_image]
                    pos_inds_low_quality = gt_argmax_overlaps[index_bool]
                    batch_argmax_overlaps[index_image][pos_inds_low_quality] \
                        = index_temp[index_bool]
                    batch_pos_mask[index_image][pos_inds_low_quality] = 1
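        # Positives take precedence: clear the negative flag wherever the
        # positive mask is set.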
        batch_neg_mask = batch_neg_mask * (1 - batch_pos_mask)
        if batch_gt_labels is not None:
            batch_anchor_gt_labels = torch.zeros((num_images, num_bboxes),
                                                 dtype=batch_gt_labels.dtype,
                                                 device=batch_gt_labels.device)
            for index_image in range(num_images):
                batch_anchor_gt_labels[index_image] = torch.index_select(
                    batch_gt_labels[index_image], 0,
                    batch_argmax_overlaps[index_image])
        else:
            batch_anchor_gt_labels = None
        return AscendAssignResult(batch_num_gts, batch_pos_mask,
                                  batch_neg_mask, batch_max_overlaps,
                                  batch_argmax_overlaps,
                                  batch_anchor_gt_labels)

mmdet/core/bbox/assigners/assign_result.py
ADDED
@@ -0,0 +1,206 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet.utils import util_mixins


class AssignResult(util_mixins.NiceRepr):
    """Stores assignments between predicted and truth boxes.

    Attributes:
        num_gts (int): the number of truth boxes considered when computing
            this assignment

        gt_inds (LongTensor): for each predicted box indicates the 1-based
            index of the assigned truth box. 0 means unassigned and -1 means
            ignore.

        max_overlaps (FloatTensor): the iou between the predicted box and its
            assigned truth box.

        labels (None | LongTensor): If specified, for each predicted box
            indicates the category label of the assigned truth box.

    Example:
        >>> # An assign result between 4 predicted boxes and 9 true boxes
        >>> # where only two boxes were assigned.
        >>> num_gts = 9
        >>> max_overlaps = torch.FloatTensor([0, .5, .9, 0])
        >>> gt_inds = torch.LongTensor([-1, 1, 2, 0])
        >>> labels = torch.LongTensor([0, 3, 4, 0])
        >>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels)
        >>> print(str(self))  # xdoctest: +IGNORE_WANT
        <AssignResult(num_gts=9, gt_inds.shape=(4,), max_overlaps.shape=(4,),
                      labels.shape=(4,))>
        >>> # Force addition of gt labels (when adding gt as proposals)
        >>> new_labels = torch.LongTensor([3, 4, 5])
        >>> self.add_gt_(new_labels)
        >>> print(str(self))  # xdoctest: +IGNORE_WANT
        <AssignResult(num_gts=9, gt_inds.shape=(7,), max_overlaps.shape=(7,),
                      labels.shape=(7,))>
    """

    def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
        self.num_gts = num_gts
        self.gt_inds = gt_inds
        self.max_overlaps = max_overlaps
        self.labels = labels
        # Interface for possible user-defined properties
        self._extra_properties = {}

    @property
    def num_preds(self):
        """int: the number of predictions in this assignment"""
        return len(self.gt_inds)

    def set_extra_property(self, key, value):
        """Set user-defined new property."""
        assert key not in self.info
        self._extra_properties[key] = value

    def get_extra_property(self, key):
        """Get user-defined property."""
        return self._extra_properties.get(key, None)

    @property
    def info(self):
        """dict: a dictionary of info about the object"""
        basic_info = {
            'num_gts': self.num_gts,
            'num_preds': self.num_preds,
            'gt_inds': self.gt_inds,
            'max_overlaps': self.max_overlaps,
            'labels': self.labels,
        }
        basic_info.update(self._extra_properties)
        return basic_info

    def __nice__(self):
        """str: a "nice" summary string describing this assign result"""
        parts = []
        parts.append(f'num_gts={self.num_gts!r}')
        if self.gt_inds is None:
            parts.append(f'gt_inds={self.gt_inds!r}')
        else:
            parts.append(f'gt_inds.shape={tuple(self.gt_inds.shape)!r}')
        if self.max_overlaps is None:
            parts.append(f'max_overlaps={self.max_overlaps!r}')
        else:
            parts.append('max_overlaps.shape='
                         f'{tuple(self.max_overlaps.shape)!r}')
        if self.labels is None:
            parts.append(f'labels={self.labels!r}')
        else:
            parts.append(f'labels.shape={tuple(self.labels.shape)!r}')
        return ', '.join(parts)

    @classmethod
    def random(cls, **kwargs):
        """Create random AssignResult for tests or debugging.

        Args:
            num_preds: number of predicted boxes
            num_gts: number of true boxes
            p_ignore (float): probability of a predicted box assigned to an
                ignored truth
            p_assigned (float): probability of a predicted box being
                assigned
            p_use_label (float | bool): with labels or not
            rng (None | int | numpy.random.RandomState): seed or state

        Returns:
            :obj:`AssignResult`: Randomly generated assign results.

        Example:
            >>> from mmdet.core.bbox.assigners.assign_result import *  # NOQA
            >>> self = AssignResult.random()
            >>> print(self.info)
        """
        from mmdet.core.bbox import demodata
        rng = demodata.ensure_rng(kwargs.get('rng', None))

        num_gts = kwargs.get('num_gts', None)
        num_preds = kwargs.get('num_preds', None)
        p_ignore = kwargs.get('p_ignore', 0.3)
        p_assigned = kwargs.get('p_assigned', 0.7)
        p_use_label = kwargs.get('p_use_label', 0.5)
        num_classes = kwargs.get('num_classes', 3)

        if num_gts is None:
            num_gts = rng.randint(0, 8)
        if num_preds is None:
            num_preds = rng.randint(0, 16)

        if num_gts == 0:
            max_overlaps = torch.zeros(num_preds, dtype=torch.float32)
            gt_inds = torch.zeros(num_preds, dtype=torch.int64)
            if p_use_label is True or p_use_label < rng.rand():
                labels = torch.zeros(num_preds, dtype=torch.int64)
            else:
                labels = None
        else:
            import numpy as np

            # Create an overlap for each predicted box
            max_overlaps = torch.from_numpy(rng.rand(num_preds))

            # Construct gt_inds for each predicted box
            is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned)
            # maximum number of assignments constraints
            n_assigned = min(num_preds, min(num_gts, is_assigned.sum()))

            assigned_idxs = np.where(is_assigned)[0]
            rng.shuffle(assigned_idxs)
            assigned_idxs = assigned_idxs[0:n_assigned]
            assigned_idxs.sort()

            is_assigned[:] = 0
            is_assigned[assigned_idxs] = True

            is_ignore = torch.from_numpy(
                rng.rand(num_preds) < p_ignore) & is_assigned

            gt_inds = torch.zeros(num_preds, dtype=torch.int64)

            true_idxs = np.arange(num_gts)
            rng.shuffle(true_idxs)
            true_idxs = torch.from_numpy(true_idxs)
            gt_inds[is_assigned] = true_idxs[:n_assigned].long()

            gt_inds = torch.from_numpy(
                rng.randint(1, num_gts + 1, size=num_preds))
            gt_inds[is_ignore] = -1
            gt_inds[~is_assigned] = 0
            max_overlaps[~is_assigned] = 0

            if p_use_label is True or p_use_label < rng.rand():
                if num_classes == 0:
                    labels = torch.zeros(num_preds, dtype=torch.int64)
                else:
                    labels = torch.from_numpy(
                        # remind that we set FG labels to [0, num_class-1]
                        # since mmdet v2.0
                        # BG cat_id: num_class
                        rng.randint(0, num_classes, size=num_preds))
                    labels[~is_assigned] = 0
            else:
                labels = None

        self = cls(num_gts, gt_inds, max_overlaps, labels)
        return self

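    # Used when gt boxes are added as proposals: one entry per gt is
    # prepended so that each gt is assigned to itself (its own 1-based
    # index, overlap 1.0) ahead of the existing predictions.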
    def add_gt_(self, gt_labels):
        """Add ground truth as assigned results.

        Args:
            gt_labels (torch.Tensor): Labels of gt boxes
        """
        self_inds = torch.arange(
            1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
        self.gt_inds = torch.cat([self_inds, self.gt_inds])

        self.max_overlaps = torch.cat(
            [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])

        if self.labels is not None:
            self.labels = torch.cat([gt_labels, self.labels])

mmdet/core/bbox/assigners/atss_assigner.py
ADDED
@@ -0,0 +1,234 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch

from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import build_iou_calculator
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module()
class ATSSAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposal will be assigned with `0` or a positive integer
    indicating the ground truth index.

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    If ``alpha`` is not None, it means that the dynamic cost
    ATSSAssigner is adopted, which is currently only used in the DDOD.

    Args:
        topk (int): number of bboxes selected on each level.
        alpha (float): param of the cost rate for each proposal, only used
            in DDOD. Default None.
        iou_calculator (dict): builder of the IoU calculator.
            Default dict(type='BboxOverlaps2D').
        ignore_iof_thr (float): IoF threshold for ignoring bboxes; boxes
            whose IoF with any ignored gt exceeds it are assigned -1.
            Default -1 (disabled).
    """

    def __init__(self,
                 topk,
                 alpha=None,
                 iou_calculator=dict(type='BboxOverlaps2D'),
                 ignore_iof_thr=-1):
        self.topk = topk
        self.alpha = alpha
        self.iou_calculator = build_iou_calculator(iou_calculator)
        self.ignore_iof_thr = ignore_iof_thr

    # https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
    def assign(self,
               bboxes,
               num_level_bboxes,
               gt_bboxes,
               gt_bboxes_ignore=None,
               gt_labels=None,
               cls_scores=None,
               bbox_preds=None):
        """Assign gt to bboxes.

        The assignment is done in following steps

        1. compute iou between all bbox (bbox of all pyramid levels) and gt
        2. compute center distance between all bbox and gt
        3. on each pyramid level, for each gt, select k bboxes whose centers
           are closest to the gt center, so we select k*l bboxes in total as
           candidates for each gt
        4. get the corresponding iou for these candidates, compute its mean
           and std, and set mean + std as the iou threshold
        5. select candidates whose iou is greater than or equal to
           the threshold as positives
        6. limit the positive samples' centers to lie inside the gt

        If ``alpha`` is not None, and ``cls_scores`` and ``bbox_preds``
        are not None, the overlaps calculation in the first step
        will also include dynamic cost, which is currently only used in
        the DDOD.

        Args:
            bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
            num_level_bboxes (List): num of bboxes in each level
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO. Default None.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
            cls_scores (list[Tensor]): Classification scores for all scale
                levels, each is a 4D-tensor, the number of channels is
                num_base_priors * num_classes. Default None.
            bbox_preds (list[Tensor]): Box energies / deltas for all scale
                levels, each is a 4D-tensor, the number of channels is
                num_base_priors * 4. Default None.

        Returns:
            :obj:`AssignResult`: The assign result.
        """
        INF = 100000000
        bboxes = bboxes[:, :4]
        num_gt, num_bboxes = gt_bboxes.size(0), bboxes.size(0)

        message = 'Invalid alpha parameter because cls_scores or ' \
                  'bbox_preds are None. If you want to use the ' \
                  'cost-based ATSSAssigner, please set cls_scores, ' \
                  'bbox_preds and self.alpha at the same time. '

        if self.alpha is None:
            # ATSSAssigner
            overlaps = self.iou_calculator(bboxes, gt_bboxes)
            if cls_scores is not None or bbox_preds is not None:
                warnings.warn(message)
        else:
            # Dynamic cost ATSSAssigner in DDOD
            assert cls_scores is not None and bbox_preds is not None, message

            # compute cls cost for bbox and GT
            cls_cost = torch.sigmoid(cls_scores[:, gt_labels])

            # compute iou between all bbox and gt
            overlaps = self.iou_calculator(bbox_preds, gt_bboxes)

            # make sure that we are in element-wise multiplication
            assert cls_cost.shape == overlaps.shape

            # overlaps is actually a cost matrix
            overlaps = cls_cost**(1 - self.alpha) * overlaps**self.alpha

        # assign 0 by default
        assigned_gt_inds = overlaps.new_full((num_bboxes, ),
                                             0,
                                             dtype=torch.long)

        if num_gt == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            max_overlaps = overlaps.new_zeros((num_bboxes, ))
            if num_gt == 0:
                # No truth, assign everything to background
                assigned_gt_inds[:] = 0
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = overlaps.new_full((num_bboxes, ),
                                                    -1,
                                                    dtype=torch.long)
            return AssignResult(
                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)

        # compute center distance between all bbox and gt
        gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
        gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
        gt_points = torch.stack((gt_cx, gt_cy), dim=1)

        bboxes_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
        bboxes_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
        bboxes_points = torch.stack((bboxes_cx, bboxes_cy), dim=1)

        distances = (bboxes_points[:, None, :] -
                     gt_points[None, :, :]).pow(2).sum(-1).sqrt()

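        # Anchors that heavily overlap ignored gt regions (iof above the
        # threshold) are pushed to INF distance so they can never become
        # ATSS candidates, and are marked -1 (ignore) in the assignment.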
        if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
                and gt_bboxes_ignore.numel() > 0 and bboxes.numel() > 0):
            ignore_overlaps = self.iou_calculator(
                bboxes, gt_bboxes_ignore, mode='iof')
            ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
            ignore_idxs = ignore_max_overlaps > self.ignore_iof_thr
            distances[ignore_idxs, :] = INF
            assigned_gt_inds[ignore_idxs] = -1

        # Selecting candidates based on the center distance
        candidate_idxs = []
        start_idx = 0
        for level, bboxes_per_level in enumerate(num_level_bboxes):
            # on each pyramid level, for each gt,
            # select k bboxes whose centers are closest to the gt center
            end_idx = start_idx + bboxes_per_level
            distances_per_level = distances[start_idx:end_idx, :]
            selectable_k = min(self.topk, bboxes_per_level)

            _, topk_idxs_per_level = distances_per_level.topk(
                selectable_k, dim=0, largest=False)
            candidate_idxs.append(topk_idxs_per_level + start_idx)
            start_idx = end_idx
        candidate_idxs = torch.cat(candidate_idxs, dim=0)

        # get the corresponding iou for these candidates, and compute the
        # mean and std, set mean + std as the iou threshold
        candidate_overlaps = overlaps[candidate_idxs, torch.arange(num_gt)]
        overlaps_mean_per_gt = candidate_overlaps.mean(0)
        overlaps_std_per_gt = candidate_overlaps.std(0)
        overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt

        is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]

        # limit the positive sample's center in gt
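        # Offset each gt's candidate indices by gt_idx * num_bboxes so they
        # address the flattened (num_gt * num_bboxes) views of the bbox
        # center coordinates built below.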
                    for gt_idx in range(num_gt):
         
     | 
| 195 | 
         
            +
                        candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
         
     | 
| 196 | 
         
            +
                    ep_bboxes_cx = bboxes_cx.view(1, -1).expand(
         
     | 
| 197 | 
         
            +
                        num_gt, num_bboxes).contiguous().view(-1)
         
     | 
| 198 | 
         
            +
                    ep_bboxes_cy = bboxes_cy.view(1, -1).expand(
         
     | 
| 199 | 
         
            +
                        num_gt, num_bboxes).contiguous().view(-1)
         
     | 
| 200 | 
         
            +
                    candidate_idxs = candidate_idxs.view(-1)
         
     | 
| 201 | 
         
            +
             
     | 
| 202 | 
         
            +
                    # calculate the left, top, right, bottom distance between positive
         
     | 
| 203 | 
         
            +
                    # bbox center and gt side
         
     | 
| 204 | 
         
            +
                    l_ = ep_bboxes_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
         
     | 
| 205 | 
         
            +
                    t_ = ep_bboxes_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
         
     | 
| 206 | 
         
            +
                    r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].view(-1, num_gt)
         
     | 
| 207 | 
         
            +
                    b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].view(-1, num_gt)
         
     | 
| 208 | 
         
            +
                    is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
         
     | 
| 209 | 
         
            +
             
     | 
| 210 | 
         
            +
                    is_pos = is_pos & is_in_gts
         
     | 
| 211 | 
         
            +
             
     | 
| 212 | 
         
            +
                    # if an anchor box is assigned to multiple gts,
         
     | 
| 213 | 
         
            +
                    # the one with the highest IoU will be selected.
         
     | 
| 214 | 
         
            +
                    overlaps_inf = torch.full_like(overlaps,
         
     | 
| 215 | 
         
            +
                                                   -INF).t().contiguous().view(-1)
         
     | 
| 216 | 
         
            +
                    index = candidate_idxs.view(-1)[is_pos.view(-1)]
         
     | 
| 217 | 
         
            +
                    overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index]
         
     | 
| 218 | 
         
            +
                    overlaps_inf = overlaps_inf.view(num_gt, -1).t()
         
     | 
| 219 | 
         
            +
             
     | 
| 220 | 
         
            +
                    max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
         
     | 
| 221 | 
         
            +
                    assigned_gt_inds[
         
     | 
| 222 | 
         
            +
                        max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1
         
     | 
| 223 | 
         
            +
             
     | 
| 224 | 
         
            +
                    if gt_labels is not None:
         
     | 
| 225 | 
         
            +
                        assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
         
     | 
| 226 | 
         
            +
                        pos_inds = torch.nonzero(
         
     | 
| 227 | 
         
            +
                            assigned_gt_inds > 0, as_tuple=False).squeeze()
         
     | 
| 228 | 
         
            +
                        if pos_inds.numel() > 0:
         
     | 
| 229 | 
         
            +
                            assigned_labels[pos_inds] = gt_labels[
         
     | 
| 230 | 
         
            +
                                assigned_gt_inds[pos_inds] - 1]
         
     | 
| 231 | 
         
            +
                    else:
         
     | 
| 232 | 
         
            +
                        assigned_labels = None
         
     | 
| 233 | 
         
            +
                    return AssignResult(
         
     | 
| 234 | 
         
            +
                        num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
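The `-INF` masking at the end is how this assigner resolves anchors that qualify as positives for several gts: every (gt, anchor) slot starts at `-INF`, only the surviving candidate pairs are written back, and a per-anchor max then keeps the gt with the highest IoU while anchors with no candidate stay unassigned. A standalone sketch of the idiom with toy values (shapes and numbers are illustrative, not taken from this file; `overlaps` is laid out gt-major here to skip the transpose):

import torch

INF = 100000000
num_gt, num_bboxes = 2, 3
overlaps = torch.tensor([[0.6, 0.2, 0.4],
                         [0.1, 0.7, 0.5]])    # gt-major: (num_gt, num_bboxes)
is_pos = torch.tensor([[True, False, True],
                       [False, True, True]])  # surviving candidate pairs
overlaps_inf = torch.full_like(overlaps, -INF).view(-1)
index = torch.arange(num_gt * num_bboxes)[is_pos.view(-1)]
overlaps_inf[index] = overlaps.view(-1)[index]
overlaps_inf = overlaps_inf.view(num_gt, -1).t()  # (num_bboxes, num_gt)
max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
print(argmax_overlaps)  # tensor([0, 1, 1]): anchor 2 keeps gt 1 (0.5 > 0.4)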
         
mmdet/core/bbox/assigners/base_assigner.py
ADDED
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod


class BaseAssigner(metaclass=ABCMeta):
    """Base assigner that assigns boxes to ground truth boxes."""

    @abstractmethod
    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        """Assign boxes to either a ground truth box or a negative sample."""
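Every concrete assigner in this commit subclasses `BaseAssigner`, implements `assign` with exactly this signature, and returns an `AssignResult`. A minimal hypothetical subclass to illustrate the contract (the name `FirstGtAssigner` and its behaviour are made up for illustration; passing `None` for `max_overlaps` mirrors how `CenterRegionAssigner` below constructs its result):

import torch

from .assign_result import AssignResult
from .base_assigner import BaseAssigner


class FirstGtAssigner(BaseAssigner):
    """Toy assigner: every bbox goes to gt #1, or background if no gt exists."""

    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
        gt_ind = 1 if num_gts > 0 else 0  # 1-based gt index; 0 = background
        assigned_gt_inds = bboxes.new_full((num_bboxes, ), gt_ind,
                                           dtype=torch.long)
        return AssignResult(num_gts, assigned_gt_inds, None, labels=None)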
    	
mmdet/core/bbox/assigners/center_region_assigner.py
ADDED
@@ -0,0 +1,336 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import build_iou_calculator
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


def scale_boxes(bboxes, scale):
    """Expand an array of boxes by a given scale.

    Args:
        bboxes (Tensor): Shape (m, 4)
        scale (float): The scale factor of bboxes

    Returns:
        (Tensor): Shape (m, 4). Scaled bboxes
    """
    assert bboxes.size(1) == 4
    w_half = (bboxes[:, 2] - bboxes[:, 0]) * .5
    h_half = (bboxes[:, 3] - bboxes[:, 1]) * .5
    x_c = (bboxes[:, 2] + bboxes[:, 0]) * .5
    y_c = (bboxes[:, 3] + bboxes[:, 1]) * .5

    w_half *= scale
    h_half *= scale

    boxes_scaled = torch.zeros_like(bboxes)
    boxes_scaled[:, 0] = x_c - w_half
    boxes_scaled[:, 2] = x_c + w_half
    boxes_scaled[:, 1] = y_c - h_half
    boxes_scaled[:, 3] = y_c + h_half
    return boxes_scaled


def is_located_in(points, bboxes):
    """Are points located in bboxes.

    Args:
      points (Tensor): Points, shape: (m, 2).
      bboxes (Tensor): Bounding boxes, shape: (n, 4).

    Return:
      Tensor: Flags indicating if points are located in bboxes, shape: (m, n).
    """
    assert points.size(1) == 2
    assert bboxes.size(1) == 4
    return (points[:, 0].unsqueeze(1) > bboxes[:, 0].unsqueeze(0)) & \
           (points[:, 0].unsqueeze(1) < bboxes[:, 2].unsqueeze(0)) & \
           (points[:, 1].unsqueeze(1) > bboxes[:, 1].unsqueeze(0)) & \
           (points[:, 1].unsqueeze(1) < bboxes[:, 3].unsqueeze(0))


def bboxes_area(bboxes):
    """Compute the area of an array of bboxes.

    Args:
        bboxes (Tensor): The coordinates of bboxes. Shape: (m, 4)

    Returns:
        Tensor: Area of the bboxes. Shape: (m, )
    """
    assert bboxes.size(1) == 4
    w = (bboxes[:, 2] - bboxes[:, 0])
    h = (bboxes[:, 3] - bboxes[:, 1])
    areas = w * h
    return areas

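
# A quick worked example of scale_boxes (the values follow directly from the
# arithmetic above): shrinking the box (0, 0, 10, 10) with scale=0.5 keeps
# its center (5, 5) and halves each side:
#   scale_boxes(torch.Tensor([[0., 0., 10., 10.]]), 0.5)
#   -> tensor([[2.5000, 2.5000, 7.5000, 7.5000]])
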
@BBOX_ASSIGNERS.register_module()
class CenterRegionAssigner(BaseAssigner):
    """Assign pixels at the center region of a bbox as positive.

    Each proposal will be assigned with `-1` or a semi-positive integer
    indicating the ground truth index.

    - -1: negative samples
    - semi-positive numbers: positive sample, index (0-based) of assigned gt

    Args:
        pos_scale (float): Threshold within which pixels are
          labelled as positive.
        neg_scale (float): Threshold above which pixels are
          labelled as negative.
        min_pos_iof (float): Minimum iof of a pixel with a gt to be
          labelled as positive. Default: 1e-2
        ignore_gt_scale (float): Threshold within which the pixels
          are ignored when the gt is labelled as shadowed. Default: 0.5
        foreground_dominate (bool): If True, the bbox will be assigned as
          positive when a gt's kernel region overlaps with another's shadowed
          (ignored) region, otherwise it is set as ignored. Default to False.
    """

    def __init__(self,
                 pos_scale,
                 neg_scale,
                 min_pos_iof=1e-2,
                 ignore_gt_scale=0.5,
                 foreground_dominate=False,
                 iou_calculator=dict(type='BboxOverlaps2D')):
        self.pos_scale = pos_scale
        self.neg_scale = neg_scale
        self.min_pos_iof = min_pos_iof
        self.ignore_gt_scale = ignore_gt_scale
        self.foreground_dominate = foreground_dominate
        self.iou_calculator = build_iou_calculator(iou_calculator)

    def get_gt_priorities(self, gt_bboxes):
        """Get gt priorities according to their areas.

        Smaller gt has higher priority.

        Args:
            gt_bboxes (Tensor): Ground truth boxes, shape (k, 4).

        Returns:
            Tensor: The priority of gts so that gts with larger priority is \
              more likely to be assigned. Shape (k, )
        """
        gt_areas = bboxes_area(gt_bboxes)
        # Rank all gt bbox areas. Smaller objects get larger priority
        _, sort_idx = gt_areas.sort(descending=True)
        sort_idx = sort_idx.argsort()
        return sort_idx
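
    # For example, gt areas (4., 1., 9.) sort descending to sort_idx (2, 0, 1)
    # and argsort of that yields priorities (1, 2, 0): the smallest box
    # (area 1.) gets the highest priority and wins contested priors in
    # assign_one_hot_gt_indices below.
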
    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        """Assign gt to bboxes.

        This method assigns gts to every bbox (proposal/anchor), each bbox \
        will be assigned with -1, or a semi-positive number. -1 means \
        negative sample, semi-positive number is the index (0-based) of \
        assigned gt.

        Args:
            bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_bboxes_ignore (tensor, optional): Ground truth bboxes that are
              labelled as `ignored`, e.g., crowd boxes in COCO.
            gt_labels (tensor, optional): Label of gt_bboxes, shape (num_gts,).

        Returns:
            :obj:`AssignResult`: The assigned result. Note that \
              shadowed_labels of shape (N, 2) is also added as an \
              `assign_result` attribute. `shadowed_labels` is a tensor \
              composed of N pairs of [anchor_ind, class_label], where N \
              is the number of anchors that lie in the outer region of a \
              gt, anchor_ind is the shadowed anchor index and class_label \
              is the shadowed class label.

        Example:
            >>> self = CenterRegionAssigner(0.2, 0.2)
            >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
            >>> gt_bboxes = torch.Tensor([[0, 0, 10, 10]])
            >>> assign_result = self.assign(bboxes, gt_bboxes)
            >>> expected_gt_inds = torch.LongTensor([1, 0])
            >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
        """
        # There are in total 5 steps in the pixel assignment
        # 1. Find core (the center region, say inner 0.2)
        #     and shadow (the relatively outer part, say inner 0.2-0.5)
        #     regions of every gt.
        # 2. Find all prior bboxes that lie in gt_core and gt_shadow regions
        # 3. Assign prior bboxes in gt_core with a one-hot id of the gt in
        #      the image.
        #    3.1. For overlapping objects, the prior bboxes in gt_core is
        #           assigned with the object with smallest area
        # 4. Assign prior bboxes with class label according to its gt id.
        #    4.1. Assign -1 to prior bboxes lying in shadowed gts
        #    4.2. Assign positive prior boxes with the corresponding label
        # 5. Find pixels lying in the shadow of an object and assign them with
        #      background label, but set the loss weight of its corresponding
        #      gt to zero.
        assert bboxes.size(1) == 4, 'bboxes must have size of 4'
        # 1. Find core positive and shadow region of every gt
        gt_core = scale_boxes(gt_bboxes, self.pos_scale)
        gt_shadow = scale_boxes(gt_bboxes, self.neg_scale)

        # 2. Find prior bboxes that lie in gt_core and gt_shadow regions
        bbox_centers = (bboxes[:, 2:4] + bboxes[:, 0:2]) / 2
        # The center points lie within the gt boxes
        is_bbox_in_gt = is_located_in(bbox_centers, gt_bboxes)
        # Only calculate bbox and gt_core IoF. This enables small prior bboxes
        #   to match large gts
        bbox_and_gt_core_overlaps = self.iou_calculator(
            bboxes, gt_core, mode='iof')
        # The center point of effective priors should be within the gt box
        is_bbox_in_gt_core = is_bbox_in_gt & (
            bbox_and_gt_core_overlaps > self.min_pos_iof)  # shape (n, k)

        is_bbox_in_gt_shadow = (
            self.iou_calculator(bboxes, gt_shadow, mode='iof') >
            self.min_pos_iof)
        # Rule out center effective positive pixels
        is_bbox_in_gt_shadow &= (~is_bbox_in_gt_core)

        num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
        if num_gts == 0 or num_bboxes == 0:
            # If no gts exist, assign all pixels to negative
            assigned_gt_ids = \
                is_bbox_in_gt_core.new_zeros((num_bboxes,),
                                             dtype=torch.long)
            pixels_in_gt_shadow = assigned_gt_ids.new_empty((0, 2))
        else:
            # Step 3: assign a one-hot gt id to each pixel, and smaller objects
            #    have high priority to assign the pixel.
            sort_idx = self.get_gt_priorities(gt_bboxes)
            assigned_gt_ids, pixels_in_gt_shadow = \
                self.assign_one_hot_gt_indices(is_bbox_in_gt_core,
                                               is_bbox_in_gt_shadow,
                                               gt_priority=sort_idx)

        if gt_bboxes_ignore is not None and gt_bboxes_ignore.numel() > 0:
            # Mark priors whose centers fall in the scaled ignored gts as -1
            gt_bboxes_ignore = scale_boxes(
                gt_bboxes_ignore, scale=self.ignore_gt_scale)
            is_bbox_in_ignored_gts = is_located_in(bbox_centers,
                                                   gt_bboxes_ignore)
            is_bbox_in_ignored_gts = is_bbox_in_ignored_gts.any(dim=1)
            assigned_gt_ids[is_bbox_in_ignored_gts] = -1

        # 4. Assign prior bboxes with class label according to its gt id.
        assigned_labels = None
        shadowed_pixel_labels = None
        if gt_labels is not None:
            # Default assigned label is the background (-1)
            assigned_labels = assigned_gt_ids.new_full((num_bboxes, ), -1)
            pos_inds = torch.nonzero(
                assigned_gt_ids > 0, as_tuple=False).squeeze()
            if pos_inds.numel() > 0:
                assigned_labels[pos_inds] = gt_labels[assigned_gt_ids[pos_inds]
                                                      - 1]
            # 5. Find pixels lying in the shadow of an object
            shadowed_pixel_labels = pixels_in_gt_shadow.clone()
            if pixels_in_gt_shadow.numel() > 0:
                pixel_idx, gt_idx =\
                    pixels_in_gt_shadow[:, 0], pixels_in_gt_shadow[:, 1]
                assert (assigned_gt_ids[pixel_idx] != gt_idx).all(), \
                    'Some pixels are dually assigned to ignore and gt!'
                shadowed_pixel_labels[:, 1] = gt_labels[gt_idx - 1]
                override = (
                    assigned_labels[pixel_idx] == shadowed_pixel_labels[:, 1])
                if self.foreground_dominate:
                    # When a pixel is both positive and shadowed, set it as pos
                    shadowed_pixel_labels = shadowed_pixel_labels[~override]
                else:
                    # When a pixel is both pos and shadowed, set it as shadowed
                    assigned_labels[pixel_idx[override]] = -1
                    assigned_gt_ids[pixel_idx[override]] = 0

        assign_result = AssignResult(
            num_gts, assigned_gt_ids, None, labels=assigned_labels)
        # Add shadowed_labels as assign_result property. Shape: (num_shadow, 2)
        assign_result.set_extra_property('shadowed_labels',
                                         shadowed_pixel_labels)
        return assign_result

    def assign_one_hot_gt_indices(self,
                                  is_bbox_in_gt_core,
                                  is_bbox_in_gt_shadow,
                                  gt_priority=None):
        """Assign only one gt index to each prior box.

        Gts with large gt_priority are more likely to be assigned.

        Args:
            is_bbox_in_gt_core (Tensor): Bool tensor indicating the bbox center
              is in the core area of a gt (e.g. 0-0.2).
              Shape: (num_prior, num_gt).
            is_bbox_in_gt_shadow (Tensor): Bool tensor indicating the bbox
              center is in the shadowed area of a gt (e.g. 0.2-0.5).
              Shape: (num_prior, num_gt).
            gt_priority (Tensor): Priorities of gts. The gt with a higher
              priority is more likely to be assigned to the bbox when the bbox
              matches with multiple gts. Shape: (num_gt, ).

        Returns:
            tuple: Returns (assigned_gt_inds, shadowed_gt_inds).

                - assigned_gt_inds: The assigned gt index of each prior bbox \
                    (i.e. index from 1 to num_gts). Shape: (num_prior, ).
                - shadowed_gt_inds: shadowed gt indices. It is a tensor of \
                    shape (num_ignore, 2) with first column being the \
                    shadowed prior bbox indices and the second column the \
                    shadowed gt indices (1-based).
        """
        num_bboxes, num_gts = is_bbox_in_gt_core.shape

        if gt_priority is None:
            gt_priority = torch.arange(
                num_gts, device=is_bbox_in_gt_core.device)
        assert gt_priority.size(0) == num_gts
        # The bigger gt_priority, the more preferable to be assigned
        # The assigned inds are by default 0 (background)
        assigned_gt_inds = is_bbox_in_gt_core.new_zeros((num_bboxes, ),
                                                        dtype=torch.long)
        # Shadowed bboxes are assigned to be background. But the corresponding
        #   label is ignored during loss calculation, which is done through
        #   shadowed_gt_inds
        shadowed_gt_inds = torch.nonzero(is_bbox_in_gt_shadow, as_tuple=False)
        if is_bbox_in_gt_core.sum() == 0:  # No gt match
            shadowed_gt_inds[:, 1] += 1  # 1-based. For consistency issue
            return assigned_gt_inds, shadowed_gt_inds

        # The priority of each prior box and gt pair. If one prior box is
        #  matched to multiple gts, only the pair with the highest priority
        #  is saved
        pair_priority = is_bbox_in_gt_core.new_full((num_bboxes, num_gts),
                                                    -1,
                                                    dtype=torch.long)

        # Each bbox could match with multiple gts.
        # The following codes deal with this situation
        # Matched bboxes (to any gt). Shape: (num_pos_anchor, )
        inds_of_match = torch.any(is_bbox_in_gt_core, dim=1)
        # The matched gt index of each positive bbox. Length >= num_pos_anchor
        #   , since one bbox could match multiple gts
        matched_bbox_gt_inds = torch.nonzero(
            is_bbox_in_gt_core, as_tuple=False)[:, 1]
        # Assign priority to each bbox-gt pair.
        pair_priority[is_bbox_in_gt_core] = gt_priority[matched_bbox_gt_inds]
        _, argmax_priority = pair_priority[inds_of_match].max(dim=1)
        assigned_gt_inds[inds_of_match] = argmax_priority + 1  # 1-based
        # Zero out the assigned anchor boxes to filter the shadowed gt indices
        is_bbox_in_gt_core[inds_of_match, argmax_priority] = 0
        # Concat the shadowed indices due to overlapping with those outside of
        #   the effective scale. Shape: (total_num_ignore, 2)
        shadowed_gt_inds = torch.cat(
            (shadowed_gt_inds, torch.nonzero(
                is_bbox_in_gt_core, as_tuple=False)),
            dim=0)
        # `is_bbox_in_gt_core` should be changed back to keep arguments intact.
        is_bbox_in_gt_core[inds_of_match, argmax_priority] = 1
        # 1-based shadowed gt indices, to be consistent with `assigned_gt_inds`
        if shadowed_gt_inds.numel() > 0:
            shadowed_gt_inds[:, 1] += 1
        return assigned_gt_inds, shadowed_gt_inds
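Expanding the class docstring's example into a runnable end-to-end check (the numbers just restate the doctest above; `get_extra_property` is assumed to be the read counterpart of the `set_extra_property` call made in `assign`):

import torch

self = CenterRegionAssigner(pos_scale=0.2, neg_scale=0.2)
bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
gt_bboxes = torch.Tensor([[0, 0, 10, 10]])
assign_result = self.assign(bboxes, gt_bboxes)
# The first prior's center (5, 5) falls in the 0.2-scaled core of the only
# gt with IoF 0.04 > min_pos_iof, so it gets gt index 1 (1-based); the
# second prior's center (15, 15) lies outside every gt, so it stays 0.
assert torch.all(assign_result.gt_inds == torch.LongTensor([1, 0]))
shadowed = assign_result.get_extra_property('shadowed_labels')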
    	
mmdet/core/bbox/assigners/grid_assigner.py
ADDED
@@ -0,0 +1,156 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import build_iou_calculator
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module()
class GridAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposal will be assigned with `-1`, `0`, or a positive integer
    indicating the ground truth index.

    - -1: don't care
    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        pos_iou_thr (float): IoU threshold for positive bboxes.
        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
        min_pos_iou (float): Minimum iou for a bbox to be considered as a
            positive bbox. Positive samples can have smaller IoU than
            pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
        gt_max_assign_all (bool): Whether to assign all bboxes with the same
            highest overlap with some gt to that gt.
    """

    def __init__(self,
                 pos_iou_thr,
                 neg_iou_thr,
                 min_pos_iou=.0,
                 gt_max_assign_all=True,
                 iou_calculator=dict(type='BboxOverlaps2D')):
        self.pos_iou_thr = pos_iou_thr
        self.neg_iou_thr = neg_iou_thr
        self.min_pos_iou = min_pos_iou
        self.gt_max_assign_all = gt_max_assign_all
        self.iou_calculator = build_iou_calculator(iou_calculator)

    def assign(self, bboxes, box_responsible_flags, gt_bboxes, gt_labels=None):
        """Assign gt to bboxes. The process is very much like the max iou
        assigner, except that positive samples are constrained within the cell
        that the gt boxes fell in.

        This method assigns a gt bbox to every bbox (proposal/anchor); each
        bbox will be assigned with -1, 0, or a positive number. -1 means don't
        care, 0 means negative sample, positive number is the index (1-based)
        of assigned gt.
        The assignment is done in following steps, the order matters.

        1. assign every bbox to -1
        2. assign proposals whose iou with all gts <= neg_iou_thr to 0
        3. for each bbox within a cell, if the iou with its nearest gt >
            pos_iou_thr and the center of that gt falls inside the cell,
            assign it to that bbox
        4. for each gt bbox, assign its nearest proposals within the cell the
            gt bbox falls in to itself.

        Args:
            bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
            box_responsible_flags (Tensor): flag to indicate whether box is
                responsible for prediction, shape(n, )
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.
        """
        num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)

        # compute iou between all gt and bboxes
        overlaps = self.iou_calculator(gt_bboxes, bboxes)

        # 1. assign -1 by default
        assigned_gt_inds = overlaps.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)

        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            max_overlaps = overlaps.new_zeros((num_bboxes, ))
            if num_gts == 0:
                # No truth, assign everything to background
                assigned_gt_inds[:] = 0
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = overlaps.new_full((num_bboxes, ),
                                                    -1,
                                                    dtype=torch.long)
            return AssignResult(
                num_gts,
                assigned_gt_inds,
                max_overlaps,
                labels=assigned_labels)

        # 2. assign negative: below
        # for each anchor, which gt best overlaps with it
        # for each anchor, the max iou of all gts
        # shape of max_overlaps == argmax_overlaps == num_bboxes
        max_overlaps, argmax_overlaps = overlaps.max(dim=0)

        if isinstance(self.neg_iou_thr, float):
            assigned_gt_inds[(max_overlaps >= 0)
                             & (max_overlaps <= self.neg_iou_thr)] = 0
        elif isinstance(self.neg_iou_thr, (tuple, list)):
            assert len(self.neg_iou_thr) == 2
            assigned_gt_inds[(max_overlaps > self.neg_iou_thr[0])
                             & (max_overlaps <= self.neg_iou_thr[1])] = 0
         
     | 
| 113 | 
         
            +
             
     | 
| 114 | 
         
            +
                    # 3. assign positive: falls into responsible cell and above
         
     | 
| 115 | 
         
            +
                    # positive IOU threshold, the order matters.
         
     | 
| 116 | 
         
            +
                    # the prior condition of comparison is to filter out all
         
     | 
| 117 | 
         
            +
                    # unrelated anchors, i.e. not box_responsible_flags
         
     | 
| 118 | 
         
            +
                    overlaps[:, ~box_responsible_flags.type(torch.bool)] = -1.
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
                    # calculate max_overlaps again, but this time we only consider IOUs
         
     | 
| 121 | 
         
            +
                    # for anchors responsible for prediction
         
     | 
| 122 | 
         
            +
                    max_overlaps, argmax_overlaps = overlaps.max(dim=0)
         
     | 
| 123 | 
         
            +
             
     | 
| 124 | 
         
            +
                    # for each gt, which anchor best overlaps with it
         
     | 
| 125 | 
         
            +
                    # for each gt, the max iou of all proposals
         
     | 
| 126 | 
         
            +
                    # shape of gt_max_overlaps == gt_argmax_overlaps == num_gts
         
     | 
| 127 | 
         
            +
                    gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)
         
     | 
| 128 | 
         
            +
             
     | 
| 129 | 
         
            +
                    pos_inds = (max_overlaps >
         
     | 
| 130 | 
         
            +
                                self.pos_iou_thr) & box_responsible_flags.type(torch.bool)
         
     | 
| 131 | 
         
            +
                    assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
         
     | 
| 132 | 
         
            +
             
     | 
| 133 | 
         
            +
                    # 4. assign positive to max overlapped anchors within responsible cell
         
     | 
| 134 | 
         
            +
                    for i in range(num_gts):
         
     | 
| 135 | 
         
            +
                        if gt_max_overlaps[i] > self.min_pos_iou:
         
     | 
| 136 | 
         
            +
                            if self.gt_max_assign_all:
         
     | 
| 137 | 
         
            +
                                max_iou_inds = (overlaps[i, :] == gt_max_overlaps[i]) & \
         
     | 
| 138 | 
         
            +
                                     box_responsible_flags.type(torch.bool)
         
     | 
| 139 | 
         
            +
                                assigned_gt_inds[max_iou_inds] = i + 1
         
     | 
| 140 | 
         
            +
                            elif box_responsible_flags[gt_argmax_overlaps[i]]:
         
     | 
| 141 | 
         
            +
                                assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
         
     | 
| 142 | 
         
            +
             
     | 
| 143 | 
         
            +
                    # assign labels of positive anchors
         
     | 
| 144 | 
         
            +
                    if gt_labels is not None:
         
     | 
| 145 | 
         
            +
                        assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
         
     | 
| 146 | 
         
            +
                        pos_inds = torch.nonzero(
         
     | 
| 147 | 
         
            +
                            assigned_gt_inds > 0, as_tuple=False).squeeze()
         
     | 
| 148 | 
         
            +
                        if pos_inds.numel() > 0:
         
     | 
| 149 | 
         
            +
                            assigned_labels[pos_inds] = gt_labels[
         
     | 
| 150 | 
         
            +
                                assigned_gt_inds[pos_inds] - 1]
         
     | 
| 151 | 
         
            +
             
     | 
| 152 | 
         
            +
                    else:
         
     | 
| 153 | 
         
            +
                        assigned_labels = None
         
     | 
| 154 | 
         
            +
             
     | 
| 155 | 
         
            +
                    return AssignResult(
         
     | 
| 156 | 
         
            +
                        num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
         
     | 
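To make the four-step policy above concrete, here is a minimal, self-contained sketch that replays steps 1-4 with plain torch on hand-written IoU values; the tensors, thresholds, and variable names are toy stand-ins, not taken from any mmdet config.

import torch

# Toy inputs: 2 gts x 3 anchors, hand-written IoUs (rows are gts).
overlaps = torch.tensor([[0.9, 0.1, 0.0],
                         [0.0, 0.2, 0.6]])
responsible = torch.tensor([True, True, False])
pos_iou_thr, neg_iou_thr, min_pos_iou = 0.5, 0.3, 0.0

# 1. assign -1 (don't care) by default
assigned = torch.full((3, ), -1, dtype=torch.long)

# 2. negative: anchors whose best IoU over all gts is <= neg_iou_thr
max_ov, _ = overlaps.max(dim=0)
assigned[(max_ov >= 0) & (max_ov <= neg_iou_thr)] = 0

# 3. positive: responsible anchors above pos_iou_thr (mask the rest first)
overlaps[:, ~responsible] = -1.
max_ov, argmax_ov = overlaps.max(dim=0)
pos = (max_ov > pos_iou_thr) & responsible
assigned[pos] = argmax_ov[pos] + 1

# 4. rescue each gt's best responsible anchor if its IoU > min_pos_iou
gt_max, gt_arg = overlaps.max(dim=1)
for i in range(overlaps.size(0)):
    if gt_max[i] > min_pos_iou and responsible[gt_arg[i]]:
        assigned[gt_arg[i]] = i + 1

print(assigned)  # tensor([ 1,  2, -1])

Note how step 4 rescues gt 1 through its best responsible anchor even though that anchor's IoU of 0.2 is below pos_iou_thr; that is exactly the low-quality matching the comments above describe, while the non-responsible anchor stays at -1.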
    	
mmdet/core/bbox/assigners/hungarian_assigner.py
ADDED
@@ -0,0 +1,139 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from scipy.optimize import linear_sum_assignment

from ..builder import BBOX_ASSIGNERS
from ..match_costs import build_match_cost
from ..transforms import bbox_cxcywh_to_xyxy
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module()
class HungarianAssigner(BaseAssigner):
    """Computes one-to-one matching between predictions and ground truth.

    This class computes an assignment between the targets and the predictions
    based on the costs. The costs are a weighted sum of three components:
    classification cost, regression L1 cost and regression iou cost. The
    targets don't include the no_object, so generally there are more
    predictions than targets. After the one-to-one matching, the un-matched
    are treated as backgrounds. Thus each query prediction will be assigned
    with `0` or a positive integer indicating the ground truth index:

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        cls_cost (dict, optional): Classification cost config. Default
            ``dict(type='ClassificationCost', weight=1.)``.
        reg_cost (dict, optional): Regression L1 cost config. Default
            ``dict(type='BBoxL1Cost', weight=1.0)``.
        iou_cost (dict, optional): Regression iou cost config. ``iou_mode``
            can be "iou" (intersection over union), "iof" (intersection over
            foreground), or "giou" (generalized intersection over union).
            Default ``dict(type='IoUCost', iou_mode='giou', weight=1.0)``.
    """

    def __init__(self,
                 cls_cost=dict(type='ClassificationCost', weight=1.),
                 reg_cost=dict(type='BBoxL1Cost', weight=1.0),
                 iou_cost=dict(type='IoUCost', iou_mode='giou', weight=1.0)):
        self.cls_cost = build_match_cost(cls_cost)
        self.reg_cost = build_match_cost(reg_cost)
        self.iou_cost = build_match_cost(iou_cost)

    def assign(self,
               bbox_pred,
               cls_pred,
               gt_bboxes,
               gt_labels,
               img_meta,
               gt_bboxes_ignore=None,
               eps=1e-7):
        """Computes one-to-one matching based on the weighted costs.

        This method assigns each query prediction to a ground truth or
        background. In `assigned_gt_inds`, -1 means don't care, 0 means
        negative sample, and a positive number is the index (1-based)
        of the assigned gt.
        The assignment is done in the following steps; the order matters.

        1. assign every prediction to -1
        2. compute the weighted costs
        3. do Hungarian matching on CPU based on the costs
        4. assign all to 0 (background) first, then for each matched pair
           between predictions and gts, treat this prediction as foreground
           and assign the corresponding gt index (plus 1) to it.

        Args:
            bbox_pred (Tensor): Predicted boxes with normalized coordinates
                (cx, cy, w, h), which are all in range [0, 1]. Shape
                [num_query, 4].
            cls_pred (Tensor): Predicted classification logits, shape
                [num_query, num_class].
            gt_bboxes (Tensor): Ground truth boxes with unnormalized
                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
            img_meta (dict): Meta information for current image.
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`. Default None.
            eps (int | float, optional): A value added to the denominator for
                numerical stability. Default 1e-7.

        Returns:
            :obj:`AssignResult`: The assigned result.
        """
        assert gt_bboxes_ignore is None, \
            'Only case when gt_bboxes_ignore is None is supported.'
        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)

        # 1. assign -1 by default
        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
                                              -1,
                                              dtype=torch.long)
        assigned_labels = bbox_pred.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)
        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            if num_gts == 0:
                # No ground truth, assign all to background
                assigned_gt_inds[:] = 0
            return AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)
        img_h, img_w, _ = img_meta['img_shape']
        factor = gt_bboxes.new_tensor([img_w, img_h, img_w,
                                       img_h]).unsqueeze(0)

        # 2. compute the weighted costs
        # classification and bbox cost
        cls_cost = self.cls_cost(cls_pred, gt_labels)
        # regression L1 cost
        normalize_gt_bboxes = gt_bboxes / factor
        reg_cost = self.reg_cost(bbox_pred, normalize_gt_bboxes)
        # regression iou cost; giou is used by default, as in official DETR
        bboxes = bbox_cxcywh_to_xyxy(bbox_pred) * factor
        iou_cost = self.iou_cost(bboxes, gt_bboxes)
        # weighted sum of the above three costs
        cost = cls_cost + reg_cost + iou_cost

        # 3. do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu()
        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        matched_row_inds = torch.from_numpy(matched_row_inds).to(
            bbox_pred.device)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(
            bbox_pred.device)

        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
        return AssignResult(
            num_gts, assigned_gt_inds, None, labels=assigned_labels)
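As a quick illustration of steps 3 and 4 above, the sketch below runs scipy's linear_sum_assignment on a hand-written 4x2 cost matrix standing in for cls_cost + reg_cost + iou_cost; all values are toys, not produced by any real model.

import torch
from scipy.optimize import linear_sum_assignment

# Toy stand-in for the summed cost: shape (num_query, num_gt).
cost = torch.tensor([[0.2, 0.9],
                     [0.8, 0.1],
                     [0.5, 0.5],
                     [0.7, 0.6]])

# 3. Hungarian matching on CPU minimizes the total matched cost
row, col = linear_sum_assignment(cost.detach().cpu())

# 4. everything is background first; matched queries get 1-based gt indices
assigned_gt_inds = torch.zeros(4, dtype=torch.long)
assigned_gt_inds[torch.from_numpy(row)] = torch.from_numpy(col) + 1
print(assigned_gt_inds)  # tensor([1, 2, 0, 0])

Because the matching is one-to-one, exactly min(num_query, num_gt) queries become foreground; the remaining queries stay at 0 (background), matching the docstring's semantics.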
    	
mmdet/core/bbox/assigners/mask_hungarian_assigner.py
ADDED
@@ -0,0 +1,125 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from scipy.optimize import linear_sum_assignment

from mmdet.core.bbox.builder import BBOX_ASSIGNERS
from mmdet.core.bbox.match_costs.builder import build_match_cost
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module()
class MaskHungarianAssigner(BaseAssigner):
    """Computes one-to-one matching between predictions and ground truth for
    masks.

    This class computes an assignment between the targets and the predictions
    based on the costs. The costs are a weighted sum of three components:
    classification cost, mask focal cost and mask dice cost. The
    targets don't include the no_object, so generally there are more
    predictions than targets. After the one-to-one matching, the un-matched
    are treated as backgrounds. Thus each query prediction will be assigned
    with `0` or a positive integer indicating the ground truth index:

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        cls_cost (:obj:`mmcv.ConfigDict` | dict): Classification cost config.
        mask_cost (:obj:`mmcv.ConfigDict` | dict): Mask cost config.
        dice_cost (:obj:`mmcv.ConfigDict` | dict): Dice cost config.
    """

    def __init__(self,
                 cls_cost=dict(type='ClassificationCost', weight=1.0),
                 mask_cost=dict(
                     type='FocalLossCost', weight=1.0, binary_input=True),
                 dice_cost=dict(type='DiceCost', weight=1.0)):
        self.cls_cost = build_match_cost(cls_cost)
        self.mask_cost = build_match_cost(mask_cost)
        self.dice_cost = build_match_cost(dice_cost)

    def assign(self,
               cls_pred,
               mask_pred,
               gt_labels,
               gt_mask,
               img_meta,
               gt_bboxes_ignore=None,
               eps=1e-7):
        """Computes one-to-one matching based on the weighted costs.

        Args:
            cls_pred (Tensor | None): Class prediction in shape
                (num_query, cls_out_channels).
            mask_pred (Tensor): Mask prediction in shape (num_query, H, W).
            gt_labels (Tensor): Labels of `gt_mask`, in shape (num_gt, ).
            gt_mask (Tensor): Ground truth mask in shape (num_gt, H, W).
            img_meta (dict): Meta information for current image.
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`. Default None.
            eps (int | float, optional): A value added to the denominator for
                numerical stability. Default 1e-7.

        Returns:
            :obj:`AssignResult`: The assigned result.
        """
        assert gt_bboxes_ignore is None, \
            'Only case when gt_bboxes_ignore is None is supported.'
        # K-Net sometimes passes cls_pred=None to this assigner,
        # so the number of queries is taken from the shape of mask_pred.
        num_gt, num_query = gt_labels.shape[0], mask_pred.shape[0]

        # 1. assign -1 by default
        assigned_gt_inds = mask_pred.new_full((num_query, ),
                                              -1,
                                              dtype=torch.long)
        assigned_labels = mask_pred.new_full((num_query, ),
                                             -1,
                                             dtype=torch.long)
        if num_gt == 0 or num_query == 0:
            # No ground truth or queries, return empty assignment
            if num_gt == 0:
                # No ground truth, assign all to background
                assigned_gt_inds[:] = 0
            return AssignResult(
                num_gt, assigned_gt_inds, None, labels=assigned_labels)

        # 2. compute the weighted costs
        # classification and mask cost
        if self.cls_cost.weight != 0 and cls_pred is not None:
            cls_cost = self.cls_cost(cls_pred, gt_labels)
        else:
            cls_cost = 0

        if self.mask_cost.weight != 0:
            # mask_pred shape = [num_query, h, w]
            # gt_mask shape = [num_gt, h, w]
            # mask_cost shape = [num_query, num_gt]
            mask_cost = self.mask_cost(mask_pred, gt_mask)
        else:
            mask_cost = 0

        if self.dice_cost.weight != 0:
            dice_cost = self.dice_cost(mask_pred, gt_mask)
        else:
            dice_cost = 0
        cost = cls_cost + mask_cost + dice_cost

        # 3. do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu()

        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        matched_row_inds = torch.from_numpy(matched_row_inds).to(
            mask_pred.device)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(
            mask_pred.device)

        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
        return AssignResult(
            num_gt, assigned_gt_inds, None, labels=assigned_labels)
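A hedged usage sketch follows; the cost weights and tensor shapes are illustrative only (real model configs differ), and running it assumes an mmdet installation with these cost modules registered.

import torch

# Illustrative weights, not the defaults of any particular model.
assigner = MaskHungarianAssigner(
    cls_cost=dict(type='ClassificationCost', weight=2.0),
    mask_cost=dict(type='FocalLossCost', weight=5.0, binary_input=True),
    dice_cost=dict(type='DiceCost', weight=5.0))

cls_pred = torch.randn(10, 80)                      # 10 queries, 80 classes
mask_pred = torch.randn(10, 32, 32)                 # query mask logits
gt_labels = torch.tensor([3, 7])
gt_mask = torch.randint(0, 2, (2, 32, 32)).float()  # binary gt masks
result = assigner.assign(cls_pred, mask_pred, gt_labels, gt_mask,
                         img_meta=dict())
print(result.gt_inds)  # 0 = background query, otherwise 1-based gt index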
    	
mmdet/core/bbox/assigners/max_iou_assigner.py
ADDED
@@ -0,0 +1,218 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import build_iou_calculator
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module()
class MaxIoUAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposal will be assigned with `-1`, or a semi-positive integer
    indicating the ground truth index.

    - -1: negative sample, no assigned gt
    - semi-positive integer: positive sample, index (0-based) of assigned gt

    Args:
        pos_iou_thr (float): IoU threshold for positive bboxes.
        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
        min_pos_iou (float): Minimum IoU for a bbox to be considered as a
            positive bbox. Positive samples can have smaller IoU than
            pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
            `min_pos_iou` is set to avoid assigning bboxes that have extremely
            small IoU with GT as positive samples. It brings about 0.3 mAP
            improvements in 1x schedule but does not affect the performance of
            3x schedule. More comparisons can be found in
            `PR #7464 <https://github.com/open-mmlab/mmdetection/pull/7464>`_.
        gt_max_assign_all (bool): Whether to assign all bboxes with the same
            highest overlap with some gt to that gt.
        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
            `gt_bboxes_ignore` is specified). Negative values mean not
            ignoring any bboxes.
        ignore_wrt_candidates (bool): Whether to compute the IoF between
            `bboxes` and `gt_bboxes_ignore`, or the contrary.
        match_low_quality (bool): Whether to allow low quality matches. This is
            usually allowed for RPN and single stage detectors, but not allowed
            in the second stage. Details are demonstrated in Step 4.
        gpu_assign_thr (int): The upper bound of the number of GT for GPU
            assign. When the number of gt is above this threshold, the
            assignment will run on CPU. Negative values mean never assigning
            on CPU.
    """

    def __init__(self,
                 pos_iou_thr,
                 neg_iou_thr,
                 min_pos_iou=.0,
                 gt_max_assign_all=True,
                 ignore_iof_thr=-1,
                 ignore_wrt_candidates=True,
                 match_low_quality=True,
                 gpu_assign_thr=-1,
                 iou_calculator=dict(type='BboxOverlaps2D')):
        self.pos_iou_thr = pos_iou_thr
        self.neg_iou_thr = neg_iou_thr
        self.min_pos_iou = min_pos_iou
        self.gt_max_assign_all = gt_max_assign_all
        self.ignore_iof_thr = ignore_iof_thr
        self.ignore_wrt_candidates = ignore_wrt_candidates
        self.gpu_assign_thr = gpu_assign_thr
        self.match_low_quality = match_low_quality
        self.iou_calculator = build_iou_calculator(iou_calculator)

    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        """Assign gt to bboxes.

        This method assigns a gt bbox to every bbox (proposal/anchor); each
        bbox will be assigned with -1 or a semi-positive number. -1 means
        negative sample, a semi-positive number is the index (0-based) of the
        assigned gt. The assignment is done in the following steps; the order
        matters.

        1. assign every bbox to the background
        2. assign proposals whose iou with all gts < neg_iou_thr to 0
        3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
           assign it to that gt
        4. for each gt bbox, assign its nearest proposals (may be more than
           one) to itself

        Args:
            bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.

        Example:
            >>> self = MaxIoUAssigner(0.5, 0.5)
            >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
            >>> gt_bboxes = torch.Tensor([[0, 0, 10, 9]])
            >>> assign_result = self.assign(bboxes, gt_bboxes)
            >>> expected_gt_inds = torch.LongTensor([1, 0])
            >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
        """
        assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
            gt_bboxes.shape[0] > self.gpu_assign_thr) else False
        # compute overlap and assign gt on CPU when the number of GT is large
        if assign_on_cpu:
            device = bboxes.device
            bboxes = bboxes.cpu()
            gt_bboxes = gt_bboxes.cpu()
            if gt_bboxes_ignore is not None:
                gt_bboxes_ignore = gt_bboxes_ignore.cpu()
            if gt_labels is not None:
                gt_labels = gt_labels.cpu()

        overlaps = self.iou_calculator(gt_bboxes, bboxes)

        if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
                and gt_bboxes_ignore.numel() > 0 and bboxes.numel() > 0):
            if self.ignore_wrt_candidates:
                ignore_overlaps = self.iou_calculator(
                    bboxes, gt_bboxes_ignore, mode='iof')
                ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
            else:
                ignore_overlaps = self.iou_calculator(
                    gt_bboxes_ignore, bboxes, mode='iof')
                ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
            overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1

        assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
        if assign_on_cpu:
            assign_result.gt_inds = assign_result.gt_inds.to(device)
            assign_result.max_overlaps = assign_result.max_overlaps.to(device)
            if assign_result.labels is not None:
                assign_result.labels = assign_result.labels.to(device)
        return assign_result

    def assign_wrt_overlaps(self, overlaps, gt_labels=None):
        """Assign w.r.t. the overlaps of bboxes with gts.

        Args:
            overlaps (Tensor): Overlaps between k gt_bboxes and n bboxes,
                shape(k, n).
            gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.
        """
        num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)

        # 1. assign -1 by default
        assigned_gt_inds = overlaps.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)

        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            max_overlaps = overlaps.new_zeros((num_bboxes, ))
            if num_gts == 0:
                # No truth, assign everything to background
                assigned_gt_inds[:] = 0
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = overlaps.new_full((num_bboxes, ),
                                                    -1,
                                                                dtype=torch.long)
         
     | 
| 163 | 
         
            +
                        return AssignResult(
         
     | 
| 164 | 
         
            +
                            num_gts,
         
     | 
| 165 | 
         
            +
                            assigned_gt_inds,
         
     | 
| 166 | 
         
            +
                            max_overlaps,
         
     | 
| 167 | 
         
            +
                            labels=assigned_labels)
         
     | 
| 168 | 
         
            +
             
     | 
| 169 | 
         
            +
                    # for each anchor, which gt best overlaps with it
         
     | 
| 170 | 
         
            +
                    # for each anchor, the max iou of all gts
         
     | 
| 171 | 
         
            +
                    max_overlaps, argmax_overlaps = overlaps.max(dim=0)
         
     | 
| 172 | 
         
            +
                    # for each gt, which anchor best overlaps with it
         
     | 
| 173 | 
         
            +
                    # for each gt, the max iou of all proposals
         
     | 
| 174 | 
         
            +
                    gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)
         
     | 
| 175 | 
         
            +
             
     | 
| 176 | 
         
            +
                    # 2. assign negative: below
         
     | 
| 177 | 
         
            +
                    # the negative inds are set to be 0
         
     | 
| 178 | 
         
            +
                    if isinstance(self.neg_iou_thr, float):
         
     | 
| 179 | 
         
            +
                        assigned_gt_inds[(max_overlaps >= 0)
         
     | 
| 180 | 
         
            +
                                         & (max_overlaps < self.neg_iou_thr)] = 0
         
     | 
| 181 | 
         
            +
                    elif isinstance(self.neg_iou_thr, tuple):
         
     | 
| 182 | 
         
            +
                        assert len(self.neg_iou_thr) == 2
         
     | 
| 183 | 
         
            +
                        assigned_gt_inds[(max_overlaps >= self.neg_iou_thr[0])
         
     | 
| 184 | 
         
            +
                                         & (max_overlaps < self.neg_iou_thr[1])] = 0
         
     | 
| 185 | 
         
            +
             
     | 
| 186 | 
         
            +
                    # 3. assign positive: above positive IoU threshold
         
     | 
| 187 | 
         
            +
                    pos_inds = max_overlaps >= self.pos_iou_thr
         
     | 
| 188 | 
         
            +
                    assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
         
     | 
| 189 | 
         
            +
             
     | 
| 190 | 
         
            +
                    if self.match_low_quality:
         
     | 
| 191 | 
         
            +
                        # Low-quality matching will overwrite the assigned_gt_inds assigned
         
     | 
| 192 | 
         
            +
                        # in Step 3. Thus, the assigned gt might not be the best one for
         
     | 
| 193 | 
         
            +
                        # prediction.
         
     | 
| 194 | 
         
            +
                        # For example, if bbox A has 0.9 and 0.8 iou with GT bbox 1 & 2,
         
     | 
| 195 | 
         
            +
                        # bbox 1 will be assigned as the best target for bbox A in step 3.
         
     | 
| 196 | 
         
            +
                        # However, if GT bbox 2's gt_argmax_overlaps = A, bbox A's
         
     | 
| 197 | 
         
            +
                        # assigned_gt_inds will be overwritten to be bbox 2.
         
     | 
| 198 | 
         
            +
                        # This might be the reason that it is not used in ROI Heads.
         
     | 
| 199 | 
         
            +
                        for i in range(num_gts):
         
     | 
| 200 | 
         
            +
                            if gt_max_overlaps[i] >= self.min_pos_iou:
         
     | 
| 201 | 
         
            +
                                if self.gt_max_assign_all:
         
     | 
| 202 | 
         
            +
                                    max_iou_inds = overlaps[i, :] == gt_max_overlaps[i]
         
     | 
| 203 | 
         
            +
                                    assigned_gt_inds[max_iou_inds] = i + 1
         
     | 
| 204 | 
         
            +
                                else:
         
     | 
| 205 | 
         
            +
                                    assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
         
     | 
| 206 | 
         
            +
             
     | 
| 207 | 
         
            +
                    if gt_labels is not None:
         
     | 
| 208 | 
         
            +
                        assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
         
     | 
| 209 | 
         
            +
                        pos_inds = torch.nonzero(
         
     | 
| 210 | 
         
            +
                            assigned_gt_inds > 0, as_tuple=False).squeeze()
         
     | 
| 211 | 
         
            +
                        if pos_inds.numel() > 0:
         
     | 
| 212 | 
         
            +
                            assigned_labels[pos_inds] = gt_labels[
         
     | 
| 213 | 
         
            +
                                assigned_gt_inds[pos_inds] - 1]
         
     | 
| 214 | 
         
            +
                    else:
         
     | 
| 215 | 
         
            +
                        assigned_labels = None
         
     | 
| 216 | 
         
            +
             
     | 
| 217 | 
         
            +
                    return AssignResult(
         
     | 
| 218 | 
         
            +
                        num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
         
     | 
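A quick way to sanity-check the three-step logic above is to feed assign_wrt_overlaps a hand-built IoU matrix. The following is a minimal sketch, not part of the commit; it assumes this repo's mmdet package is importable, and the IoU values and thresholds are invented.

import torch
from mmdet.core.bbox.assigners import MaxIoUAssigner

# 2 gts (rows) x 3 candidate boxes (columns), IoUs chosen by hand
overlaps = torch.tensor([[0.72, 0.10, 0.45],
                         [0.05, 0.81, 0.30]])
assigner = MaxIoUAssigner(pos_iou_thr=0.5, neg_iou_thr=0.4)
result = assigner.assign_wrt_overlaps(overlaps)
# box 0 -> gt 1 (0.72 >= 0.5), box 1 -> gt 2 (0.81 >= 0.5),
# box 2 -> -1 (0.45 sits between neg and pos thresholds, so it is ignored)
print(result.gt_inds)  # tensor([ 1,  2, -1])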
    	
mmdet/core/bbox/assigners/point_assigner.py
ADDED
@@ -0,0 +1,134 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ..builder import BBOX_ASSIGNERS
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module()
class PointAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each point.

    Each point will be assigned with `0`, or a positive integer
    indicating the ground truth index.

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt
    """

    def __init__(self, scale=4, pos_num=3):
        self.scale = scale
        self.pos_num = pos_num

    def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        """Assign gt to points.

        This method assigns a gt bbox to every point. Each point will be
        assigned with 0 or a positive number: 0 means negative sample
        (background), and a positive number is the index (1-based) of the
        assigned gt.
        The assignment is done in following steps, the order matters.

        1. assign every point to the background (0)
        2. A point is assigned to some gt bbox if
            (i) the point is within the k closest points to the gt bbox
            (ii) this point is closer to the gt than to any other gt bbox

        Args:
            points (Tensor): points to be assigned, shape (n, 3), where the
                last dimension stands for (x, y, stride).
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
                NOTE: currently unused.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.
        """
        num_points = points.shape[0]
        num_gts = gt_bboxes.shape[0]

        if num_gts == 0 or num_points == 0:
            # If no truth, assign everything to the background
            assigned_gt_inds = points.new_full((num_points, ),
                                               0,
                                               dtype=torch.long)
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = points.new_full((num_points, ),
                                                  -1,
                                                  dtype=torch.long)
            return AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)

        points_xy = points[:, :2]
        points_stride = points[:, 2]
        points_lvl = torch.log2(
            points_stride).int()  # [3...,4...,5...,6...,7...]
        lvl_min, lvl_max = points_lvl.min(), points_lvl.max()

        # assign gt box
        gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
        gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
        scale = self.scale
        gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) +
                          torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int()
        gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max)

        # stores the assigned gt index of each point
        assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long)
        # stores the assigned gt dist (to this point) of each point
        assigned_gt_dist = points.new_full((num_points, ), float('inf'))
        points_range = torch.arange(points.shape[0])

        for idx in range(num_gts):
            gt_lvl = gt_bboxes_lvl[idx]
            # get the index of points in this level
            lvl_idx = gt_lvl == points_lvl
            points_index = points_range[lvl_idx]
            # get the points in this level
            lvl_points = points_xy[lvl_idx, :]
            # get the center point of gt
            gt_point = gt_bboxes_xy[[idx], :]
            # get width and height of gt
            gt_wh = gt_bboxes_wh[[idx], :]
            # compute the distance between gt center and
            #   all points in this level
            points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)
            # find the nearest k points to gt center in this level
            min_dist, min_dist_index = torch.topk(
                points_gt_dist, self.pos_num, largest=False)
            # the index of nearest k points to gt center in this level
            min_dist_points_index = points_index[min_dist_index]
            # The less_than_recorded_index stores the index
            #   of min_dist that is less than the assigned_gt_dist. Where
            #   assigned_gt_dist stores the dist from previous assigned gt
            #   (if exist) to each point.
            less_than_recorded_index = min_dist < assigned_gt_dist[
                min_dist_points_index]
            # The min_dist_points_index stores the index of points that
            #   satisfy:
            #   (1) it is k nearest to current gt center in this level.
            #   (2) it is closer to current gt center than to other gt
            #       centers.
            min_dist_points_index = min_dist_points_index[
                less_than_recorded_index]
            # assign the result
            assigned_gt_inds[min_dist_points_index] = idx + 1
            assigned_gt_dist[min_dist_points_index] = min_dist[
                less_than_recorded_index]

        if gt_labels is not None:
            assigned_labels = assigned_gt_inds.new_full((num_points, ), -1)
            pos_inds = torch.nonzero(
                assigned_gt_inds > 0, as_tuple=False).squeeze()
            if pos_inds.numel() > 0:
                assigned_labels[pos_inds] = gt_labels[
                    assigned_gt_inds[pos_inds] - 1]
        else:
            assigned_labels = None

        return AssignResult(
            num_gts, assigned_gt_inds, None, labels=assigned_labels)
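The heart of PointAssigner is the level-selection rule: a gt box only competes for points on the FPN level whose stride best fits the box's size, via the log2 formula above. A standalone sketch of just that rule (the point coordinates and gt box are made up; scale=4 as in the default):

import torch

points = torch.tensor([[4., 4., 8.],     # (x, y, stride) on the stride-8 level
                       [12., 4., 8.],
                       [8., 8., 16.]])   # one point on the stride-16 level
points_lvl = torch.log2(points[:, 2]).int()          # tensor([3, 3, 4])

gt = torch.tensor([[0., 0., 32., 32.]])              # one 32x32 gt box
wh = (gt[:, 2:] - gt[:, :2]).clamp(min=1e-6)
scale = 4
gt_lvl = ((torch.log2(wh[:, 0] / scale) +
           torch.log2(wh[:, 1] / scale)) / 2).int()  # log2(32 / 4) = 3
print(gt_lvl)  # tensor([3], dtype=torch.int32): matched against stride-8 points only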
    	
mmdet/core/bbox/assigners/region_assigner.py
ADDED
@@ -0,0 +1,222 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet.core import anchor_inside_flags
from ..builder import BBOX_ASSIGNERS
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


def calc_region(bbox, ratio, stride, featmap_size=None):
    """Calculate the region of the box defined by the ratio; the ratio is
    measured from the center of the box to every edge."""
    # project bbox on the feature map
    f_bbox = bbox / stride
    x1 = torch.round((1 - ratio) * f_bbox[0] + ratio * f_bbox[2])
    y1 = torch.round((1 - ratio) * f_bbox[1] + ratio * f_bbox[3])
    x2 = torch.round(ratio * f_bbox[0] + (1 - ratio) * f_bbox[2])
    y2 = torch.round(ratio * f_bbox[1] + (1 - ratio) * f_bbox[3])
    if featmap_size is not None:
        x1 = x1.clamp(min=0, max=featmap_size[1])
        y1 = y1.clamp(min=0, max=featmap_size[0])
        x2 = x2.clamp(min=0, max=featmap_size[1])
        y2 = y2.clamp(min=0, max=featmap_size[0])
    return (x1, y1, x2, y2)


def anchor_ctr_inside_region_flags(anchors, stride, region):
    """Get flags indicating whether anchor centers are inside the region."""
    x1, y1, x2, y2 = region
    f_anchors = anchors / stride
    x = (f_anchors[:, 0] + f_anchors[:, 2]) * 0.5
    y = (f_anchors[:, 1] + f_anchors[:, 3]) * 0.5
    flags = (x >= x1) & (x <= x2) & (y >= y1) & (y <= y2)
    return flags


@BBOX_ASSIGNERS.register_module()
class RegionAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposal will be assigned with `-1`, `0`, or a positive integer
    indicating the ground truth index.

    - -1: don't care
    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        center_ratio: ratio of the region in the center of the bbox used
            to define positive samples.
        ignore_ratio: ratio of the region used to define ignore samples.
    """

    def __init__(self, center_ratio=0.2, ignore_ratio=0.5):
        self.center_ratio = center_ratio
        self.ignore_ratio = ignore_ratio

    def assign(self,
               mlvl_anchors,
               mlvl_valid_flags,
               gt_bboxes,
               img_meta,
               featmap_sizes,
               anchor_scale,
               anchor_strides,
               gt_bboxes_ignore=None,
               gt_labels=None,
               allowed_border=0):
        """Assign gt to anchors.

        This method assigns a gt bbox to every bbox (proposal/anchor); each
        bbox will be assigned with -1, 0, or a positive number. -1 means
        don't care, 0 means negative sample, and a positive number is the
        index (1-based) of the assigned gt.

        The assignment is done in the following steps, and the order matters.

        1. Assign every anchor to 0 (negative)
        2. (For each gt_bboxes) Compute ignore flags based on ignore_region
           then assign -1 to anchors w.r.t. ignore flags
        3. (For each gt_bboxes) Compute pos flags based on center_region then
           assign gt_bboxes to anchors w.r.t. pos flags
        4. (For each gt_bboxes) Compute ignore flags based on adjacent anchor
           level then assign -1 to anchors w.r.t. ignore flags
        5. Assign anchors outside of the image to -1

        Args:
            mlvl_anchors (list[Tensor]): Multi level anchors.
            mlvl_valid_flags (list[Tensor]): Multi level valid flags.
            gt_bboxes (Tensor): Ground truth bboxes of image, shape (k, 4).
            img_meta (dict): Meta info of image.
            featmap_sizes (list[Tensor]): Feature map size of each level.
            anchor_scale (int): Scale of the anchor.
            anchor_strides (list[int]): Stride of the anchor.
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
            allowed_border (int, optional): The border to allow the valid
                anchor. Defaults to 0.

        Returns:
            :obj:`AssignResult`: The assign result.
        """
        if gt_bboxes_ignore is not None:
            raise NotImplementedError

        num_gts = gt_bboxes.shape[0]
        num_bboxes = sum(x.shape[0] for x in mlvl_anchors)

        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            max_overlaps = gt_bboxes.new_zeros((num_bboxes, ))
            assigned_gt_inds = gt_bboxes.new_zeros((num_bboxes, ),
                                                   dtype=torch.long)
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = gt_bboxes.new_full((num_bboxes, ),
                                                     -1,
                                                     dtype=torch.long)
            return AssignResult(
                num_gts,
                assigned_gt_inds,
                max_overlaps,
                labels=assigned_labels)

        num_lvls = len(mlvl_anchors)
        r1 = (1 - self.center_ratio) / 2
        r2 = (1 - self.ignore_ratio) / 2

        scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
                           (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
        min_anchor_size = scale.new_full(
            (1, ), float(anchor_scale * anchor_strides[0]))
        target_lvls = torch.floor(
            torch.log2(scale) - torch.log2(min_anchor_size) + 0.5)
        target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long()

        # 1. assign 0 (negative) by default
        mlvl_assigned_gt_inds = []
        mlvl_ignore_flags = []
        for lvl in range(num_lvls):
            h, w = featmap_sizes[lvl]
            assert h * w == mlvl_anchors[lvl].shape[0]
            assigned_gt_inds = gt_bboxes.new_full((h * w, ),
                                                  0,
                                                  dtype=torch.long)
            ignore_flags = torch.zeros_like(assigned_gt_inds)
            mlvl_assigned_gt_inds.append(assigned_gt_inds)
            mlvl_ignore_flags.append(ignore_flags)

        for gt_id in range(num_gts):
            lvl = target_lvls[gt_id].item()
            featmap_size = featmap_sizes[lvl]
            stride = anchor_strides[lvl]
            anchors = mlvl_anchors[lvl]
            gt_bbox = gt_bboxes[gt_id, :4]

            # Compute regions
            ignore_region = calc_region(gt_bbox, r2, stride, featmap_size)
            ctr_region = calc_region(gt_bbox, r1, stride, featmap_size)

            # 2. Assign -1 to ignore flags
            ignore_flags = anchor_ctr_inside_region_flags(
                anchors, stride, ignore_region)
            mlvl_assigned_gt_inds[lvl][ignore_flags] = -1

            # 3. Assign gt_bboxes to pos flags
            pos_flags = anchor_ctr_inside_region_flags(anchors, stride,
                                                       ctr_region)
            mlvl_assigned_gt_inds[lvl][pos_flags] = gt_id + 1

            # 4. Assign -1 to ignore adjacent lvl
            if lvl > 0:
                d_lvl = lvl - 1
                d_anchors = mlvl_anchors[d_lvl]
                d_featmap_size = featmap_sizes[d_lvl]
                d_stride = anchor_strides[d_lvl]
                d_ignore_region = calc_region(gt_bbox, r2, d_stride,
                                              d_featmap_size)
                ignore_flags = anchor_ctr_inside_region_flags(
                    d_anchors, d_stride, d_ignore_region)
                mlvl_ignore_flags[d_lvl][ignore_flags] = 1
            if lvl < num_lvls - 1:
                u_lvl = lvl + 1
                u_anchors = mlvl_anchors[u_lvl]
                u_featmap_size = featmap_sizes[u_lvl]
                u_stride = anchor_strides[u_lvl]
                u_ignore_region = calc_region(gt_bbox, r2, u_stride,
                                              u_featmap_size)
                ignore_flags = anchor_ctr_inside_region_flags(
                    u_anchors, u_stride, u_ignore_region)
                mlvl_ignore_flags[u_lvl][ignore_flags] = 1

        # 4. (cont.) Assign -1 to ignore adjacent lvl
        for lvl in range(num_lvls):
            ignore_flags = mlvl_ignore_flags[lvl]
            # the flags are stored as 0/1 long tensors; cast to bool so they
            # act as a mask rather than as gather indices
            mlvl_assigned_gt_inds[lvl][ignore_flags.bool()] = -1

        # 5. Assign -1 to anchors outside of the image
        flat_assigned_gt_inds = torch.cat(mlvl_assigned_gt_inds)
        flat_anchors = torch.cat(mlvl_anchors)
        flat_valid_flags = torch.cat(mlvl_valid_flags)
        assert (flat_assigned_gt_inds.shape[0] == flat_anchors.shape[0] ==
                flat_valid_flags.shape[0])
        inside_flags = anchor_inside_flags(flat_anchors, flat_valid_flags,
                                           img_meta['img_shape'],
                                           allowed_border)
        outside_flags = ~inside_flags
        flat_assigned_gt_inds[outside_flags] = -1

        if gt_labels is not None:
            assigned_labels = torch.zeros_like(flat_assigned_gt_inds)
            # mask with the flattened indices; the per-level
            # `assigned_gt_inds` left over from the loop above has the
            # wrong shape here
            pos_flags = flat_assigned_gt_inds > 0
            assigned_labels[pos_flags] = gt_labels[
                flat_assigned_gt_inds[pos_flags] - 1]
        else:
            assigned_labels = None

        return AssignResult(
            num_gts, flat_assigned_gt_inds, None, labels=assigned_labels)
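calc_region and anchor_ctr_inside_region_flags are pure tensor helpers, so the nesting of the center and ignore regions is easy to probe in isolation. A toy check, not part of the commit (values invented; it assumes calc_region is importable from this module; with the defaults, center_ratio=0.2 gives r1=0.4 and ignore_ratio=0.5 gives r2=0.25):

import torch
from mmdet.core.bbox.assigners.region_assigner import calc_region

gt_bbox = torch.tensor([16., 16., 48., 48.])  # a 32x32 box
stride = 8
r1 = (1 - 0.2) / 2  # 0.4, shrinks the box tightly toward its center
r2 = (1 - 0.5) / 2  # 0.25, a looser region around the same center
print(calc_region(gt_bbox, r1, stride))  # (4, 4, 4, 4) on the stride-8 grid
print(calc_region(gt_bbox, r2, stride))  # (3, 3, 5, 5): the ignore region encloses it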
    	
mmdet/core/bbox/assigners/sim_ota_assigner.py
ADDED
@@ -0,0 +1,257 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn.functional as F

from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import bbox_overlaps
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module()
class SimOTAAssigner(BaseAssigner):
    """Computes matching between predictions and ground truth.

    Args:
        center_radius (int | float, optional): Ground truth center size
            to judge whether a prior is in center. Default 2.5.
        candidate_topk (int, optional): The candidate top-k used to select
            top-k ious for the dynamic-k calculation. Default 10.
        iou_weight (int | float, optional): The scale factor for regression
            iou cost. Default 3.0.
        cls_weight (int | float, optional): The scale factor for
            classification cost. Default 1.0.
    """

    def __init__(self,
                 center_radius=2.5,
                 candidate_topk=10,
                 iou_weight=3.0,
                 cls_weight=1.0):
        self.center_radius = center_radius
        self.candidate_topk = candidate_topk
        self.iou_weight = iou_weight
        self.cls_weight = cls_weight

    def assign(self,
               pred_scores,
               priors,
               decoded_bboxes,
               gt_bboxes,
               gt_labels,
               gt_bboxes_ignore=None,
               eps=1e-7):
        """Assign gt to priors using SimOTA. It will switch to CPU mode when
        GPU is out of memory.

        Args:
            pred_scores (Tensor): Classification scores of one image,
                a 2D-Tensor with shape [num_priors, num_classes].
            priors (Tensor): All priors of one image, a 2D-Tensor with shape
                [num_priors, 4] in [cx, cy, stride_w, stride_h] format.
            decoded_bboxes (Tensor): Predicted bboxes, a 2D-Tensor with shape
                [num_priors, 4] in [tl_x, tl_y, br_x, br_y] format.
            gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
                with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (Tensor): Ground truth labels of one image, a Tensor
                with shape [num_gts].
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
            eps (float): A value added to the denominator for numerical
                stability. Default 1e-7.

        Returns:
            :obj:`AssignResult`: The assigned result.
        """
        try:
            assign_result = self._assign(pred_scores, priors, decoded_bboxes,
                                         gt_bboxes, gt_labels,
                                         gt_bboxes_ignore, eps)
            return assign_result
        except RuntimeError:
            origin_device = pred_scores.device
            warnings.warn('OOM RuntimeError is raised due to the huge memory '
                          'cost during label assignment. CPU mode is applied '
                          'in this batch. If you want to avoid this issue, '
                          'try to reduce the batch size or image size.')
            torch.cuda.empty_cache()

            pred_scores = pred_scores.cpu()
            priors = priors.cpu()
            decoded_bboxes = decoded_bboxes.cpu()
            gt_bboxes = gt_bboxes.cpu().float()
            gt_labels = gt_labels.cpu()

            assign_result = self._assign(pred_scores, priors, decoded_bboxes,
                                         gt_bboxes, gt_labels,
                                         gt_bboxes_ignore, eps)
            assign_result.gt_inds = assign_result.gt_inds.to(origin_device)
            assign_result.max_overlaps = assign_result.max_overlaps.to(
                origin_device)
            assign_result.labels = assign_result.labels.to(origin_device)

            return assign_result

    def _assign(self,
                pred_scores,
                priors,
                decoded_bboxes,
                gt_bboxes,
                gt_labels,
                gt_bboxes_ignore=None,
                eps=1e-7):
        """Assign gt to priors using SimOTA.

        Args:
            pred_scores (Tensor): Classification scores of one image,
                a 2D-Tensor with shape [num_priors, num_classes].
            priors (Tensor): All priors of one image, a 2D-Tensor with shape
                [num_priors, 4] in [cx, cy, stride_w, stride_h] format.
            decoded_bboxes (Tensor): Predicted bboxes, a 2D-Tensor with shape
                [num_priors, 4] in [tl_x, tl_y, br_x, br_y] format.
            gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
                with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (Tensor): Ground truth labels of one image, a Tensor
                with shape [num_gts].
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
            eps (float): A value added to the denominator for numerical
                stability. Default 1e-7.

        Returns:
            :obj:`AssignResult`: The assigned result.
        """
        INF = 100000.0
        num_gt = gt_bboxes.size(0)
        num_bboxes = decoded_bboxes.size(0)

        # assign 0 by default
        assigned_gt_inds = decoded_bboxes.new_full((num_bboxes, ),
                                                   0,
                                                   dtype=torch.long)
        valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info(
            priors, gt_bboxes)
        valid_decoded_bbox = decoded_bboxes[valid_mask]
        valid_pred_scores = pred_scores[valid_mask]
        num_valid = valid_decoded_bbox.size(0)

        if num_gt == 0 or num_bboxes == 0 or num_valid == 0:
            # No ground truth or boxes, return empty assignment
            max_overlaps = decoded_bboxes.new_zeros((num_bboxes, ))
            if num_gt == 0:
                # No truth, assign everything to background
                assigned_gt_inds[:] = 0
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = decoded_bboxes.new_full((num_bboxes, ),
                                                          -1,
                                                          dtype=torch.long)
            return AssignResult(
                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)

        pairwise_ious = bbox_overlaps(valid_decoded_bbox, gt_bboxes)
        iou_cost = -torch.log(pairwise_ious + eps)

        gt_onehot_label = (
            F.one_hot(gt_labels.to(torch.int64),
                      pred_scores.shape[-1]).float().unsqueeze(0).repeat(
                          num_valid, 1, 1))

        valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)
        cls_cost = (
            F.binary_cross_entropy(
                valid_pred_scores.to(dtype=torch.float32).sqrt_(),
                gt_onehot_label,
                reduction='none',
            ).sum(-1).to(dtype=valid_pred_scores.dtype))

        cost_matrix = (
            cls_cost * self.cls_weight + iou_cost * self.iou_weight +
            (~is_in_boxes_and_center) * INF)

        matched_pred_ious, matched_gt_inds = \
            self.dynamic_k_matching(
                cost_matrix, pairwise_ious, num_gt, valid_mask)

        # convert to AssignResult format
        assigned_gt_inds[valid_mask] = matched_gt_inds + 1
        assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
        assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
        max_overlaps = assigned_gt_inds.new_full((num_bboxes, ),
                                                 -INF,
                                                 dtype=torch.float32)
        max_overlaps[valid_mask] = matched_pred_ious
        return AssignResult(
            num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)

    def get_in_gt_and_in_center_info(self, priors, gt_bboxes):
        num_gt = gt_bboxes.size(0)

        repeated_x = priors[:, 0].unsqueeze(1).repeat(1, num_gt)
        repeated_y = priors[:, 1].unsqueeze(1).repeat(1, num_gt)
        repeated_stride_x = priors[:, 2].unsqueeze(1).repeat(1, num_gt)
        repeated_stride_y = priors[:, 3].unsqueeze(1).repeat(1, num_gt)

        # whether prior centers are in gt bboxes, shape: [n_prior, n_gt]
        l_ = repeated_x - gt_bboxes[:, 0]
        t_ = repeated_y - gt_bboxes[:, 1]
        r_ = gt_bboxes[:, 2] - repeated_x
        b_ = gt_bboxes[:, 3] - repeated_y

        deltas = torch.stack([l_, t_, r_, b_], dim=1)
        is_in_gts = deltas.min(dim=1).values > 0
        is_in_gts_all = is_in_gts.sum(dim=1) > 0

        # whether prior centers are in gt center regions
        gt_cxs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
        gt_cys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
        ct_box_l = gt_cxs - self.center_radius * repeated_stride_x
        ct_box_t = gt_cys - self.center_radius * repeated_stride_y
        ct_box_r = gt_cxs + self.center_radius * repeated_stride_x
        ct_box_b = gt_cys + self.center_radius * repeated_stride_y

        cl_ = repeated_x - ct_box_l
        ct_ = repeated_y - ct_box_t
        cr_ = ct_box_r - repeated_x
        cb_ = ct_box_b - repeated_y

        ct_deltas = torch.stack([cl_, ct_, cr_, cb_], dim=1)
        is_in_cts = ct_deltas.min(dim=1).values > 0
        is_in_cts_all = is_in_cts.sum(dim=1) > 0

        # in boxes or in centers, shape: [num_priors]
        is_in_gts_or_centers = is_in_gts_all | is_in_cts_all

        # both in boxes and centers, shape: [num_fg, num_gt]
        is_in_boxes_and_centers = (
            is_in_gts[is_in_gts_or_centers, :]
            & is_in_cts[is_in_gts_or_centers, :])
        return is_in_gts_or_centers, is_in_boxes_and_centers

    def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask):
        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
        # select candidate topk ious for dynamic-k calculation
        candidate_topk = min(self.candidate_topk, pairwise_ious.size(0))
        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
        # calculate dynamic k for each gt
        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(
                cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
            matching_matrix[:, gt_idx][pos_idx] = 1

        del topk_ious, dynamic_ks, pos_idx

        prior_match_gt_mask = matching_matrix.sum(1) > 1
        if prior_match_gt_mask.sum() > 0:
            cost_min, cost_argmin = torch.min(
                cost[prior_match_gt_mask, :], dim=1)
            matching_matrix[prior_match_gt_mask, :] *= 0
            matching_matrix[prior_match_gt_mask, cost_argmin] = 1
        # get foreground mask inside box and center prior
        fg_mask_inboxes = matching_matrix.sum(1) > 0
        valid_mask[valid_mask.clone()] = fg_mask_inboxes

        matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
        matched_pred_ious = (matching_matrix *
                             pairwise_ious).sum(1)[fg_mask_inboxes]
        return matched_pred_ious, matched_gt_inds
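
In short, SimOTA keeps only priors whose centers fall inside a gt box or its center region, prices each (prior, gt) pair with weighted classification BCE plus -log(IoU) costs, and lets dynamic_k_matching pick a per-gt number of lowest-cost priors (k derived from the summed top-k IoUs). Below is one hypothetical standalone call; the import path follows this repo's layout, but the tensors and shapes are illustrative assumptions, not part of the commit:

# --- usage sketch (not part of sim_ota_assigner.py; inputs are made up) ---
# In practice the YOLOX head builds these tensors from its predictions.
import torch
from mmdet.core.bbox.assigners import SimOTAAssigner

num_priors, num_classes = 8, 3
pred_scores = torch.rand(num_priors, num_classes)        # sigmoid cls scores
# priors: [cx, cy, stride_w, stride_h]; centers scattered over a 32x32 image
priors = torch.cat([torch.rand(num_priors, 2) * 32,
                    torch.full((num_priors, 2), 8.)], dim=1)
# crude valid xyxy boxes: sorting each row ascending guarantees x1<=x2, y1<=y2
decoded_bboxes = (torch.rand(num_priors, 4) * 32).sort(dim=1).values
gt_bboxes = torch.tensor([[0., 0., 16., 16.], [8., 8., 32., 32.]])
gt_labels = torch.tensor([0, 2])

assigner = SimOTAAssigner(center_radius=2.5, candidate_topk=10)
result = assigner.assign(pred_scores, priors, decoded_bboxes,
                         gt_bboxes, gt_labels)
# result.gt_inds: 0 = background, i + 1 = matched to the i-th gt
print(result.gt_inds, result.labels)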
    	
mmdet/core/bbox/assigners/task_aligned_assigner.py ADDED
@@ -0,0 +1,151 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import build_iou_calculator
from .assign_result import AssignResult
from .base_assigner import BaseAssigner

INF = 100000000


@BBOX_ASSIGNERS.register_module()
class TaskAlignedAssigner(BaseAssigner):
    """Task aligned assigner used in the paper:
    `TOOD: Task-aligned One-stage Object Detection.
    <https://arxiv.org/abs/2108.07755>`_.

    Assign a corresponding gt bbox or background to each predicted bbox.
    Each bbox will be assigned with `0` or a positive integer
    indicating the ground truth index.

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        topk (int): Number of bboxes selected in each level.
        iou_calculator (dict): Config dict for iou calculator.
            Default: dict(type='BboxOverlaps2D')
    """

    def __init__(self, topk, iou_calculator=dict(type='BboxOverlaps2D')):
        assert topk >= 1
        self.topk = topk
        self.iou_calculator = build_iou_calculator(iou_calculator)

    def assign(self,
               pred_scores,
               decode_bboxes,
               anchors,
               gt_bboxes,
               gt_bboxes_ignore=None,
               gt_labels=None,
               alpha=1,
               beta=6):
        """Assign gt to bboxes.

        The assignment is done in following steps

        1. compute alignment metric between all bbox (bbox of all pyramid
           levels) and gt
        2. select top-k bbox as candidates for each gt
        3. limit the positive sample's center in gt (because the anchor-free
           detector can only predict positive distances)

        Args:
            pred_scores (Tensor): predicted class probability,
                shape(n, num_classes)
            decode_bboxes (Tensor): predicted bounding boxes, shape(n, 4)
            anchors (Tensor): pre-defined anchors, shape(n, 4).
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`TaskAlignedAssignResult`: The assign result.
        """
        anchors = anchors[:, :4]
        num_gt, num_bboxes = gt_bboxes.size(0), anchors.size(0)
        # compute alignment metric between all bbox and gt
        overlaps = self.iou_calculator(decode_bboxes, gt_bboxes).detach()
        bbox_scores = pred_scores[:, gt_labels].detach()
        # assign 0 by default
        assigned_gt_inds = anchors.new_full((num_bboxes, ),
                                            0,
                                            dtype=torch.long)
        assign_metrics = anchors.new_zeros((num_bboxes, ))

        if num_gt == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            max_overlaps = anchors.new_zeros((num_bboxes, ))
            if num_gt == 0:
                # No gt boxes, assign everything to background
                assigned_gt_inds[:] = 0
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = anchors.new_full((num_bboxes, ),
                                                   -1,
                                                   dtype=torch.long)
            assign_result = AssignResult(
                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
            assign_result.assign_metrics = assign_metrics
            return assign_result

        # select top-k bboxes as candidates for each gt
        alignment_metrics = bbox_scores**alpha * overlaps**beta
        topk = min(self.topk, alignment_metrics.size(0))
        _, candidate_idxs = alignment_metrics.topk(topk, dim=0, largest=True)
        candidate_metrics = alignment_metrics[candidate_idxs,
                                              torch.arange(num_gt)]
        is_pos = candidate_metrics > 0

        # limit the positive sample's center in gt
        anchors_cx = (anchors[:, 0] + anchors[:, 2]) / 2.0
        anchors_cy = (anchors[:, 1] + anchors[:, 3]) / 2.0
        for gt_idx in range(num_gt):
            candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
        ep_anchors_cx = anchors_cx.view(1, -1).expand(
            num_gt, num_bboxes).contiguous().view(-1)
        ep_anchors_cy = anchors_cy.view(1, -1).expand(
            num_gt, num_bboxes).contiguous().view(-1)
        candidate_idxs = candidate_idxs.view(-1)

        # calculate the left, top, right, bottom distance between positive
        # bbox center and gt side
        l_ = ep_anchors_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
        t_ = ep_anchors_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
        r_ = gt_bboxes[:, 2] - ep_anchors_cx[candidate_idxs].view(-1, num_gt)
        b_ = gt_bboxes[:, 3] - ep_anchors_cy[candidate_idxs].view(-1, num_gt)
        is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
        is_pos = is_pos & is_in_gts

        # if an anchor box is assigned to multiple gts,
        # the one with the highest iou will be selected.
        overlaps_inf = torch.full_like(overlaps,
                                       -INF).t().contiguous().view(-1)
        index = candidate_idxs.view(-1)[is_pos.view(-1)]
        overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index]
        overlaps_inf = overlaps_inf.view(num_gt, -1).t()

        max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
        assigned_gt_inds[
            max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1
        assign_metrics[max_overlaps != -INF] = alignment_metrics[
            max_overlaps != -INF, argmax_overlaps[max_overlaps != -INF]]

        if gt_labels is not None:
            assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
            pos_inds = torch.nonzero(
                assigned_gt_inds > 0, as_tuple=False).squeeze()
            if pos_inds.numel() > 0:
                assigned_labels[pos_inds] = gt_labels[
                    assigned_gt_inds[pos_inds] - 1]
        else:
            assigned_labels = None
        assign_result = AssignResult(
            num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
        assign_result.assign_metrics = assign_result.assign_metrics if False else assign_metrics
        return assign_result
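
The heart of this assigner is step 1's alignment metric, t = s**alpha * u**beta, where s is the predicted score for the gt's class and u is the IoU with the gt; it appears above as `alignment_metrics = bbox_scores**alpha * overlaps**beta`. A small hedged sketch with made-up numbers shows how the default alpha=1, beta=6 makes localization quality dominate candidate ranking:

# --- illustrative sketch (not part of task_aligned_assigner.py) ---
# Made-up scores and IoUs, not values from the commit.
import torch

alpha, beta = 1, 6                         # defaults used by assign()
bbox_scores = torch.tensor([[0.90, 0.20],  # (num_priors, num_gt) cls scores
                            [0.55, 0.70]])
overlaps = torch.tensor([[0.40, 0.10],     # (num_priors, num_gt) IoUs
                         [0.85, 0.60]])
alignment_metrics = bbox_scores**alpha * overlaps**beta
print(alignment_metrics)
# approx. [[0.0037, 0.0000002],
#          [0.2074, 0.0327]]
# With beta=6 the IoU term dominates: for the first gt, the well-localized
# prior (score 0.55, IoU 0.85) far outranks the confident but poorly
# localized one (score 0.90, IoU 0.40).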
    	
mmdet/core/bbox/assigners/uniform_assigner.py ADDED
@@ -0,0 +1,135 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import build_iou_calculator
from ..transforms import bbox_xyxy_to_cxcywh
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module()
class UniformAssigner(BaseAssigner):
    """Uniform Matching between the anchors and gt boxes, which can achieve
    balance in positive anchors; gt_bboxes_ignore is not considered for now.

    Args:
        pos_ignore_thr (float): the threshold to ignore positive anchors
        neg_ignore_thr (float): the threshold to ignore negative anchors
        match_times (int): Number of positive anchors for each gt box.
            Default 4.
        iou_calculator (dict): iou_calculator config
    """

    def __init__(self,
                 pos_ignore_thr,
                 neg_ignore_thr,
                 match_times=4,
                 iou_calculator=dict(type='BboxOverlaps2D')):
        self.match_times = match_times
        self.pos_ignore_thr = pos_ignore_thr
        self.neg_ignore_thr = neg_ignore_thr
        self.iou_calculator = build_iou_calculator(iou_calculator)

    def assign(self,
               bbox_pred,
               anchor,
               gt_bboxes,
               gt_bboxes_ignore=None,
               gt_labels=None):
        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)

        # 1. assign 0 by default
        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
                                              0,
                                              dtype=torch.long)
        assigned_labels = bbox_pred.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)
        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            if num_gts == 0:
                # No ground truth, assign all to background
                assigned_gt_inds[:] = 0
            assign_result = AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)
            assign_result.set_extra_property(
                'pos_idx', bbox_pred.new_empty(0, dtype=torch.bool))
            assign_result.set_extra_property('pos_predicted_boxes',
                                             bbox_pred.new_empty((0, 4)))
            assign_result.set_extra_property('target_boxes',
                                             bbox_pred.new_empty((0, 4)))
            return assign_result

        # 2. Compute the L1 cost between boxes
        # Note that we use both anchors and predicted boxes
        cost_bbox = torch.cdist(
            bbox_xyxy_to_cxcywh(bbox_pred),
            bbox_xyxy_to_cxcywh(gt_bboxes),
            p=1)
        cost_bbox_anchors = torch.cdist(
            bbox_xyxy_to_cxcywh(anchor), bbox_xyxy_to_cxcywh(gt_bboxes), p=1)

        # We found that topk function has different results in cpu and
        # cuda mode. In order to ensure consistency with the source code,
        # we also use cpu mode.
        # TODO: Check whether the performance of cpu and cuda are the same.
        C = cost_bbox.cpu()
        C1 = cost_bbox_anchors.cpu()

        # self.match_times x n
        index = torch.topk(
            C,  # C has shape (num_bboxes, num_gts)
            k=self.match_times,
            dim=0,
            largest=False)[1]

        # self.match_times x n
        index1 = torch.topk(C1, k=self.match_times, dim=0, largest=False)[1]
        # (self.match_times*2) x n
        indexes = torch.cat((index, index1),
                            dim=1).reshape(-1).to(bbox_pred.device)

        pred_overlaps = self.iou_calculator(bbox_pred, gt_bboxes)
        anchor_overlaps = self.iou_calculator(anchor, gt_bboxes)
        pred_max_overlaps, _ = pred_overlaps.max(dim=1)
        anchor_max_overlaps, _ = anchor_overlaps.max(dim=0)

        # 3. Compute the ignore indexes using gt_bboxes and predicted boxes
        ignore_idx = pred_max_overlaps > self.neg_ignore_thr
        assigned_gt_inds[ignore_idx] = -1

        # 4. Compute the ignore indexes of positive samples using anchors
        # and predicted boxes
        pos_gt_index = torch.arange(
            0, C1.size(1),
            device=bbox_pred.device).repeat(self.match_times * 2)
        pos_ious = anchor_overlaps[indexes, pos_gt_index]
        pos_ignore_idx = pos_ious < self.pos_ignore_thr

        pos_gt_index_with_ignore = pos_gt_index + 1
            +
                    pos_gt_index_with_ignore[pos_ignore_idx] = -1
         
     | 
| 113 | 
         
            +
                    assigned_gt_inds[indexes] = pos_gt_index_with_ignore
         
     | 
| 114 | 
         
            +
             
     | 
| 115 | 
         
            +
                    if gt_labels is not None:
         
     | 
| 116 | 
         
            +
                        assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
         
     | 
| 117 | 
         
            +
                        pos_inds = torch.nonzero(
         
     | 
| 118 | 
         
            +
                            assigned_gt_inds > 0, as_tuple=False).squeeze()
         
     | 
| 119 | 
         
            +
                        if pos_inds.numel() > 0:
         
     | 
| 120 | 
         
            +
                            assigned_labels[pos_inds] = gt_labels[
         
     | 
| 121 | 
         
            +
                                assigned_gt_inds[pos_inds] - 1]
         
     | 
| 122 | 
         
            +
                    else:
         
     | 
| 123 | 
         
            +
                        assigned_labels = None
         
     | 
| 124 | 
         
            +
             
     | 
| 125 | 
         
            +
                    assign_result = AssignResult(
         
     | 
| 126 | 
         
            +
                        num_gts,
         
     | 
| 127 | 
         
            +
                        assigned_gt_inds,
         
     | 
| 128 | 
         
            +
                        anchor_max_overlaps,
         
     | 
| 129 | 
         
            +
                        labels=assigned_labels)
         
     | 
| 130 | 
         
            +
                    assign_result.set_extra_property('pos_idx', ~pos_ignore_idx)
         
     | 
| 131 | 
         
            +
                    assign_result.set_extra_property('pos_predicted_boxes',
         
     | 
| 132 | 
         
            +
                                                     bbox_pred[indexes])
         
     | 
| 133 | 
         
            +
                    assign_result.set_extra_property('target_boxes',
         
     | 
| 134 | 
         
            +
                                                     gt_bboxes[pos_gt_index])
         
     | 
| 135 | 
         
            +
                    return assign_result
         
     | 
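For context, the assignment above implements YOLOF-style uniform matching: each ground-truth box takes the `match_times` predictions and the `match_times` anchors with the smallest L1 cost in center-size space as positive candidates, then filters them by IoU. Below is a minimal self-contained sketch (not part of the diff) of just the top-k selection step; the box values and `match_times` are made up, and `xyxy_to_cxcywh` is a hypothetical stand-in for mmdet's `bbox_xyxy_to_cxcywh`.

import torch


def xyxy_to_cxcywh(boxes):
    # hypothetical stand-in for mmdet's bbox_xyxy_to_cxcywh
    cx = (boxes[:, 0] + boxes[:, 2]) / 2
    cy = (boxes[:, 1] + boxes[:, 3]) / 2
    w = boxes[:, 2] - boxes[:, 0]
    h = boxes[:, 3] - boxes[:, 1]
    return torch.stack([cx, cy, w, h], dim=1)


torch.manual_seed(0)
preds = torch.rand(100, 4)
preds[:, 2:] += preds[:, :2]  # ensure x2 > x1 and y2 > y1
gts = torch.rand(5, 4)
gts[:, 2:] += gts[:, :2]

match_times = 4
# pairwise L1 cost in (cx, cy, w, h) space, shape (num_preds, num_gts)
cost = torch.cdist(xyxy_to_cxcywh(preds), xyxy_to_cxcywh(gts), p=1)
# for each gt, indices of the match_times cheapest predictions
topk_idx = torch.topk(cost, k=match_times, dim=0, largest=False)[1]
print(topk_idx.shape)  # torch.Size([4, 5]): match_times candidates per gt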
    	
        mmdet/core/bbox/builder.py
ADDED

@@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, build_from_cfg

BBOX_ASSIGNERS = Registry('bbox_assigner')
BBOX_SAMPLERS = Registry('bbox_sampler')
BBOX_CODERS = Registry('bbox_coder')


def build_assigner(cfg, **default_args):
    """Builder of box assigner."""
    return build_from_cfg(cfg, BBOX_ASSIGNERS, default_args)


def build_sampler(cfg, **default_args):
    """Builder of box sampler."""
    return build_from_cfg(cfg, BBOX_SAMPLERS, default_args)


def build_bbox_coder(cfg, **default_args):
    """Builder of box coder."""
    return build_from_cfg(cfg, BBOX_CODERS, default_args)
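As a usage sketch (not part of the diff): with mmdet installed, any class registered in `BBOX_ASSIGNERS` can be instantiated from a plain config dict. `MaxIoUAssigner` and the threshold values below are just one standard example.

from mmdet.core.bbox.builder import build_assigner

# the dict's 'type' key selects the registered class; remaining keys
# become constructor arguments
assigner = build_assigner(
    dict(type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.4))
print(type(assigner).__name__)  # MaxIoUAssigner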
    	
        mmdet/core/bbox/coder/__init__.py
ADDED

@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_bbox_coder import BaseBBoxCoder
from .bucketing_bbox_coder import BucketingBBoxCoder
from .delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
from .distance_point_bbox_coder import DistancePointBBoxCoder
from .legacy_delta_xywh_bbox_coder import LegacyDeltaXYWHBBoxCoder
from .pseudo_bbox_coder import PseudoBBoxCoder
from .tblr_bbox_coder import TBLRBBoxCoder
from .yolo_bbox_coder import YOLOBBoxCoder

__all__ = [
    'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder',
    'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'YOLOBBoxCoder',
    'BucketingBBoxCoder', 'DistancePointBBoxCoder'
]
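A corresponding sketch for the coders exported above: each is registered in `BBOX_CODERS`, so configs refer to them by class name. The means/stds below are the conventional Faster R-CNN values, shown only as an example.

from mmdet.core.bbox.builder import build_bbox_coder

bbox_coder = build_bbox_coder(
    dict(
        type='DeltaXYWHBBoxCoder',
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2]))
print(type(bbox_coder).__name__)  # DeltaXYWHBBoxCoder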
    	
        mmdet/core/bbox/coder/base_bbox_coder.py
ADDED

@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod


class BaseBBoxCoder(metaclass=ABCMeta):
    """Base bounding box coder."""

    def __init__(self, **kwargs):
        pass

    @abstractmethod
    def encode(self, bboxes, gt_bboxes):
        """Encode deltas between bboxes and ground truth boxes."""

    @abstractmethod
    def decode(self, bboxes, bboxes_pred):
        """Decode the predicted bboxes according to prediction and base
        boxes."""
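To make the abstract contract concrete, here is a toy subclass sketch (illustrative only, similar in spirit to the `PseudoBBoxCoder` exported above): a concrete coder only needs to implement `encode` and `decode`.

from mmdet.core.bbox.coder.base_bbox_coder import BaseBBoxCoder


class IdentityBBoxCoder(BaseBBoxCoder):
    """Toy coder whose regression targets are the gt boxes themselves."""

    def encode(self, bboxes, gt_bboxes):
        # targets are just the ground-truth boxes, untransformed
        return gt_bboxes

    def decode(self, bboxes, bboxes_pred):
        # predictions are already absolute boxes
        return bboxes_pred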
    	
        mmdet/core/bbox/coder/bucketing_bbox_coder.py
ADDED

@@ -0,0 +1,351 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import torch
import torch.nn.functional as F

from ..builder import BBOX_CODERS
from ..transforms import bbox_rescale
from .base_bbox_coder import BaseBBoxCoder


@BBOX_CODERS.register_module()
class BucketingBBoxCoder(BaseBBoxCoder):
    """Bucketing BBox Coder for Side-Aware Boundary Localization (SABL).

    Boundary Localization with Bucketing and Bucketing Guided Rescoring
    are implemented here.

    Please refer to https://arxiv.org/abs/1912.04260 for more details.

    Args:
        num_buckets (int): Number of buckets.
        scale_factor (int): Scale factor of proposals to generate buckets.
        offset_topk (int): Topk buckets are used to generate
            bucket fine regression targets. Defaults to 2.
        offset_upperbound (float): Offset upperbound to generate
            bucket fine regression targets, used to avoid too large
            offset displacements. Defaults to 1.0.
        cls_ignore_neighbor (bool): Whether to ignore the second nearest
            bucket. Defaults to True.
        clip_border (bool, optional): Whether to clip objects outside the
            border of the image. Defaults to True.
    """

    def __init__(self,
                 num_buckets,
                 scale_factor,
                 offset_topk=2,
                 offset_upperbound=1.0,
                 cls_ignore_neighbor=True,
                 clip_border=True):
        super(BucketingBBoxCoder, self).__init__()
        self.num_buckets = num_buckets
        self.scale_factor = scale_factor
        self.offset_topk = offset_topk
        self.offset_upperbound = offset_upperbound
        self.cls_ignore_neighbor = cls_ignore_neighbor
        self.clip_border = clip_border

    def encode(self, bboxes, gt_bboxes):
        """Get bucketing estimation and fine regression targets during
        training.

        Args:
            bboxes (torch.Tensor): Source boxes, e.g., object proposals.
            gt_bboxes (torch.Tensor): Target of the transformation, e.g.,
                ground truth boxes.

        Returns:
            encoded_bboxes (tuple[Tensor]): Bucketing estimation and fine
                regression targets and weights.
        """
        assert bboxes.size(0) == gt_bboxes.size(0)
        assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
        encoded_bboxes = bbox2bucket(bboxes, gt_bboxes, self.num_buckets,
                                     self.scale_factor, self.offset_topk,
                                     self.offset_upperbound,
                                     self.cls_ignore_neighbor)
        return encoded_bboxes

    def decode(self, bboxes, pred_bboxes, max_shape=None):
        """Apply transformation `pred_bboxes` to `bboxes`.

        Args:
            bboxes (torch.Tensor): Basic boxes.
            pred_bboxes (tuple[torch.Tensor]): Predictions for bucketing
                estimation and fine regression.
            max_shape (tuple[int], optional): Maximum shape of boxes.
                Defaults to None.

        Returns:
            torch.Tensor: Decoded boxes.
        """
        assert len(pred_bboxes) == 2
        cls_preds, offset_preds = pred_bboxes
        assert cls_preds.size(0) == bboxes.size(0) and offset_preds.size(
            0) == bboxes.size(0)
        decoded_bboxes = bucket2bbox(bboxes, cls_preds, offset_preds,
                                     self.num_buckets, self.scale_factor,
                                     max_shape, self.clip_border)

        return decoded_bboxes


@mmcv.jit(coderize=True)
def generat_buckets(proposals, num_buckets, scale_factor=1.0):
    """Generate buckets w.r.t. bucket number and scale factor of proposals.

    Args:
        proposals (Tensor): Shape (n, 4)
        num_buckets (int): Number of buckets.
        scale_factor (float): Scale factor to rescale proposals.

    Returns:
        tuple[Tensor]: (bucket_w, bucket_h, l_buckets, r_buckets,
         t_buckets, d_buckets), where side_num = ceil(num_buckets / 2).

            - bucket_w: Width of buckets on x-axis. Shape (n, ).
            - bucket_h: Height of buckets on y-axis. Shape (n, ).
            - l_buckets: Left buckets. Shape (n, side_num).
            - r_buckets: Right buckets. Shape (n, side_num).
            - t_buckets: Top buckets. Shape (n, side_num).
            - d_buckets: Down buckets. Shape (n, side_num).
    """
    proposals = bbox_rescale(proposals, scale_factor)

    # number of buckets on each side
    side_num = int(np.ceil(num_buckets / 2.0))
    pw = proposals[..., 2] - proposals[..., 0]
    ph = proposals[..., 3] - proposals[..., 1]
    px1 = proposals[..., 0]
    py1 = proposals[..., 1]
    px2 = proposals[..., 2]
    py2 = proposals[..., 3]

    bucket_w = pw / num_buckets
    bucket_h = ph / num_buckets

    # left buckets
    l_buckets = px1[:, None] + (0.5 + torch.arange(
        0, side_num).to(proposals).float())[None, :] * bucket_w[:, None]
    # right buckets
    r_buckets = px2[:, None] - (0.5 + torch.arange(
        0, side_num).to(proposals).float())[None, :] * bucket_w[:, None]
    # top buckets
    t_buckets = py1[:, None] + (0.5 + torch.arange(
        0, side_num).to(proposals).float())[None, :] * bucket_h[:, None]
    # down buckets
    d_buckets = py2[:, None] - (0.5 + torch.arange(
        0, side_num).to(proposals).float())[None, :] * bucket_h[:, None]
    return bucket_w, bucket_h, l_buckets, r_buckets, t_buckets, d_buckets


@mmcv.jit(coderize=True)
def bbox2bucket(proposals,
                gt,
                num_buckets,
                scale_factor,
                offset_topk=2,
                offset_upperbound=1.0,
                cls_ignore_neighbor=True):
    """Generate bucket estimation and fine regression targets.

    Args:
        proposals (Tensor): Shape (n, 4)
        gt (Tensor): Shape (n, 4)
        num_buckets (int): Number of buckets.
        scale_factor (float): Scale factor to rescale proposals.
        offset_topk (int): Topk buckets are used to generate
            bucket fine regression targets. Defaults to 2.
        offset_upperbound (float): Offset allowance to generate
            bucket fine regression targets, used to avoid too large
            offset displacements. Defaults to 1.0.
        cls_ignore_neighbor (bool): Whether to ignore the second nearest
            bucket. Defaults to True.

    Returns:
        tuple[Tensor]: (offsets, offsets_weights, bucket_labels, cls_weights).

            - offsets: Fine regression targets. Shape (n, num_buckets*2).
            - offsets_weights: Fine regression weights. \
                Shape (n, num_buckets*2).
            - bucket_labels: Bucketing estimation labels. \
                Shape (n, num_buckets*2).
            - cls_weights: Bucketing estimation weights. \
                Shape (n, num_buckets*2).
    """
    assert proposals.size() == gt.size()

    # generate buckets
    proposals = proposals.float()
    gt = gt.float()
    (bucket_w, bucket_h, l_buckets, r_buckets, t_buckets,
     d_buckets) = generat_buckets(proposals, num_buckets, scale_factor)

    gx1 = gt[..., 0]
    gy1 = gt[..., 1]
    gx2 = gt[..., 2]
    gy2 = gt[..., 3]

    # generate offset targets and weights
    # offsets from buckets to gts
    l_offsets = (l_buckets - gx1[:, None]) / bucket_w[:, None]
    r_offsets = (r_buckets - gx2[:, None]) / bucket_w[:, None]
    t_offsets = (t_buckets - gy1[:, None]) / bucket_h[:, None]
    d_offsets = (d_buckets - gy2[:, None]) / bucket_h[:, None]

    # select top-k nearest buckets
    l_topk, l_label = l_offsets.abs().topk(
        offset_topk, dim=1, largest=False, sorted=True)
    r_topk, r_label = r_offsets.abs().topk(
        offset_topk, dim=1, largest=False, sorted=True)
    t_topk, t_label = t_offsets.abs().topk(
        offset_topk, dim=1, largest=False, sorted=True)
    d_topk, d_label = d_offsets.abs().topk(
        offset_topk, dim=1, largest=False, sorted=True)

    offset_l_weights = l_offsets.new_zeros(l_offsets.size())
    offset_r_weights = r_offsets.new_zeros(r_offsets.size())
    offset_t_weights = t_offsets.new_zeros(t_offsets.size())
    offset_d_weights = d_offsets.new_zeros(d_offsets.size())
    inds = torch.arange(0, proposals.size(0)).to(proposals).long()

    # generate offset weights of top-k nearest buckets
    for k in range(offset_topk):
        if k >= 1:
            offset_l_weights[inds, l_label[:, k]] = (
                l_topk[:, k] < offset_upperbound).float()
            offset_r_weights[inds, r_label[:, k]] = (
                r_topk[:, k] < offset_upperbound).float()
            offset_t_weights[inds, t_label[:, k]] = (
                t_topk[:, k] < offset_upperbound).float()
            offset_d_weights[inds, d_label[:, k]] = (
                d_topk[:, k] < offset_upperbound).float()
        else:
            offset_l_weights[inds, l_label[:, k]] = 1.0
            offset_r_weights[inds, r_label[:, k]] = 1.0
            offset_t_weights[inds, t_label[:, k]] = 1.0
            offset_d_weights[inds, d_label[:, k]] = 1.0

    offsets = torch.cat([l_offsets, r_offsets, t_offsets, d_offsets], dim=-1)
    offsets_weights = torch.cat(
        [offset_l_weights, offset_r_weights, offset_t_weights,
         offset_d_weights], dim=-1)

    # generate bucket labels and weights
    side_num = int(np.ceil(num_buckets / 2.0))
    labels = torch.stack(
        [l_label[:, 0], r_label[:, 0], t_label[:, 0], d_label[:, 0]], dim=-1)

    batch_size = labels.size(0)
    bucket_labels = F.one_hot(labels.view(-1),
                              side_num).view(batch_size, -1).float()
    bucket_cls_l_weights = (l_offsets.abs() < 1).float()
    bucket_cls_r_weights = (r_offsets.abs() < 1).float()
    bucket_cls_t_weights = (t_offsets.abs() < 1).float()
    bucket_cls_d_weights = (d_offsets.abs() < 1).float()
    bucket_cls_weights = torch.cat(
        [bucket_cls_l_weights, bucket_cls_r_weights, bucket_cls_t_weights,
         bucket_cls_d_weights], dim=-1)
    # ignore second nearest buckets for cls if necessary
    if cls_ignore_neighbor:
        bucket_cls_weights = (~((bucket_cls_weights == 1) &
                                (bucket_labels == 0))).float()
    else:
        bucket_cls_weights[:] = 1.0
    return offsets, offsets_weights, bucket_labels, bucket_cls_weights


@mmcv.jit(coderize=True)
def bucket2bbox(proposals,
                cls_preds,
                offset_preds,
                num_buckets,
                scale_factor=1.0,
                max_shape=None,
                clip_border=True):
    """Apply bucketing estimation (cls preds) and fine regression (offset
    preds) to generate det bboxes.

    Args:
        proposals (Tensor): Boxes to be transformed. Shape (n, 4)
        cls_preds (Tensor): Bucketing estimation. Shape (n, num_buckets*2).
        offset_preds (Tensor): Fine regression. Shape (n, num_buckets*2).
        num_buckets (int): Number of buckets.
        scale_factor (float): Scale factor to rescale proposals.
        max_shape (tuple[int, int]): Maximum bounds for boxes, specified
            as (H, W).
        clip_border (bool, optional): Whether to clip objects outside the
            border of the image. Defaults to True.

    Returns:
        tuple[Tensor]: (bboxes, loc_confidence).

            - bboxes: Predicted bboxes. Shape (n, 4)
            - loc_confidence: Localization confidence of predicted bboxes.
                Shape (n,).
    """

    side_num = int(np.ceil(num_buckets / 2.0))
    cls_preds = cls_preds.view(-1, side_num)
    offset_preds = offset_preds.view(-1, side_num)

    scores = F.softmax(cls_preds, dim=1)
    score_topk, score_label = scores.topk(2, dim=1, largest=True, sorted=True)

    rescaled_proposals = bbox_rescale(proposals, scale_factor)

    pw = rescaled_proposals[..., 2] - rescaled_proposals[..., 0]
    ph = rescaled_proposals[..., 3] - rescaled_proposals[..., 1]
    px1 = rescaled_proposals[..., 0]
    py1 = rescaled_proposals[..., 1]
    px2 = rescaled_proposals[..., 2]
    py2 = rescaled_proposals[..., 3]

    bucket_w = pw / num_buckets
    bucket_h = ph / num_buckets

    score_inds_l = score_label[0::4, 0]
    score_inds_r = score_label[1::4, 0]
    score_inds_t = score_label[2::4, 0]
    score_inds_d = score_label[3::4, 0]
    l_buckets = px1 + (0.5 + score_inds_l.float()) * bucket_w
    r_buckets = px2 - (0.5 + score_inds_r.float()) * bucket_w
    t_buckets = py1 + (0.5 + score_inds_t.float()) * bucket_h
    d_buckets = py2 - (0.5 + score_inds_d.float()) * bucket_h

    offsets = offset_preds.view(-1, 4, side_num)
    inds = torch.arange(proposals.size(0)).to(proposals).long()
    l_offsets = offsets[:, 0, :][inds, score_inds_l]
    r_offsets = offsets[:, 1, :][inds, score_inds_r]
    t_offsets = offsets[:, 2, :][inds, score_inds_t]
    d_offsets = offsets[:, 3, :][inds, score_inds_d]

    x1 = l_buckets - l_offsets * bucket_w
    x2 = r_buckets - r_offsets * bucket_w
    y1 = t_buckets - t_offsets * bucket_h
    y2 = d_buckets - d_offsets * bucket_h

    if clip_border and max_shape is not None:
        x1 = x1.clamp(min=0, max=max_shape[1] - 1)
        y1 = y1.clamp(min=0, max=max_shape[0] - 1)
        x2 = x2.clamp(min=0, max=max_shape[1] - 1)
        y2 = y2.clamp(min=0, max=max_shape[0] - 1)
    bboxes = torch.cat([x1[:, None], y1[:, None], x2[:, None], y2[:, None]],
                       dim=-1)

    # bucketing guided rescoring
    loc_confidence = score_topk[:, 0]
    top2_neighbor_inds = (score_label[:, 0] - score_label[:, 1]).abs() == 1
    loc_confidence += score_topk[:, 1] * top2_neighbor_inds.float()
    loc_confidence = loc_confidence.view(-1, 4).mean(dim=1)

    return bboxes, loc_confidence
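A hedged round-trip sketch, assuming mmdet is installed: encode one proposal against a ground-truth box, then feed the resulting one-hot bucket labels and offset targets back through `decode` as if they were perfect predictions. Since the offsets were computed from the same buckets that the labels select, the decoded box should recover the ground truth up to floating-point error. The box values are arbitrary.

import torch
from mmdet.core.bbox.coder import BucketingBBoxCoder

coder = BucketingBBoxCoder(num_buckets=14, scale_factor=1.7)
proposals = torch.tensor([[20., 20., 120., 90.]])
gts = torch.tensor([[25., 18., 110., 95.]])

offsets, offset_weights, bucket_labels, cls_weights = coder.encode(
    proposals, gts)
print(offsets.shape)  # (1, 28): 4 sides x ceil(num_buckets / 2) buckets

# one-hot labels act as maximally confident cls logits after softmax
decoded, loc_conf = coder.decode(proposals, (bucket_labels, offsets))
print(decoded)   # approximately equal to gts
print(loc_conf)  # bucketing-guided localization confidence, shape (1,)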
    	
        mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
    ADDED
    
    | 
         @@ -0,0 +1,392 @@ 
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import numpy as np
import torch

from ..builder import BBOX_CODERS
from .base_bbox_coder import BaseBBoxCoder


@BBOX_CODERS.register_module()
class DeltaXYWHBBoxCoder(BaseBBoxCoder):
    """Delta XYWH BBox coder.

    Following the practice in `R-CNN <https://arxiv.org/abs/1311.2524>`_,
    this coder encodes bbox (x1, y1, x2, y2) into delta (dx, dy, dw, dh) and
    decodes delta (dx, dy, dw, dh) back to the original bbox (x1, y1, x2, y2).

    Args:
        target_means (Sequence[float]): Denormalizing means of target for
            delta coordinates.
        target_stds (Sequence[float]): Denormalizing standard deviation of
            target for delta coordinates.
        clip_border (bool, optional): Whether to clip boxes that extend
            beyond the border of the image. Defaults to True.
        add_ctr_clamp (bool): Whether to add center clamping: when enabled,
            the predicted box is clamped if its center is too far away from
            the original anchor's center. Only used by YOLOF. Default False.
        ctr_clamp (int): The maximum pixel shift to clamp. Only used by
            YOLOF. Default 32.
    """

    def __init__(self,
                 target_means=(0., 0., 0., 0.),
                 target_stds=(1., 1., 1., 1.),
                 clip_border=True,
                 add_ctr_clamp=False,
                 ctr_clamp=32):
        super(BaseBBoxCoder, self).__init__()
        self.means = target_means
        self.stds = target_stds
        self.clip_border = clip_border
        self.add_ctr_clamp = add_ctr_clamp
        self.ctr_clamp = ctr_clamp

    def encode(self, bboxes, gt_bboxes):
        """Get box regression transformation deltas that can be used to
        transform the ``bboxes`` into the ``gt_bboxes``.

        Args:
            bboxes (torch.Tensor): Source boxes, e.g., object proposals.
            gt_bboxes (torch.Tensor): Target of the transformation, e.g.,
                ground-truth boxes.

        Returns:
            torch.Tensor: Box transformation deltas
        """

        assert bboxes.size(0) == gt_bboxes.size(0)
        assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
        encoded_bboxes = bbox2delta(bboxes, gt_bboxes, self.means, self.stds)
        return encoded_bboxes

    def decode(self,
               bboxes,
               pred_bboxes,
               max_shape=None,
               wh_ratio_clip=16 / 1000):
        """Apply transformation `pred_bboxes` to `bboxes`.

        Args:
            bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4)
            pred_bboxes (Tensor): Encoded offsets with respect to each roi.
                Has shape (B, N, num_classes * 4) or (B, N, 4) or
                (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
                when rois is a grid of anchors. Offset encoding follows [1]_.
            max_shape (Sequence[int] or torch.Tensor or Sequence[
                Sequence[int]], optional): Maximum bounds for boxes,
                specifies (H, W, C) or (H, W). If bboxes shape is (B, N, 4),
                then max_shape should be a Sequence[Sequence[int]]
                and the length of max_shape should also be B.
            wh_ratio_clip (float, optional): The allowed ratio between
                width and height.

        Returns:
            torch.Tensor: Decoded boxes.
        """

        assert pred_bboxes.size(0) == bboxes.size(0)
        if pred_bboxes.ndim == 3:
            assert pred_bboxes.size(1) == bboxes.size(1)

        if pred_bboxes.ndim == 2 and not torch.onnx.is_in_onnx_export():
            # single image decode
            decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means,
                                        self.stds, max_shape, wh_ratio_clip,
                                        self.clip_border, self.add_ctr_clamp,
                                        self.ctr_clamp)
        else:
            if pred_bboxes.ndim == 3 and not torch.onnx.is_in_onnx_export():
                warnings.warn(
                    'DeprecationWarning: onnx_delta2bbox is deprecated '
                    'in the case of batch decoding and non-ONNX, '
                    "please use 'delta2bbox' instead. In order to improve "
                    'the decoding speed, the batch function will no '
                    'longer be supported. ')
            decoded_bboxes = onnx_delta2bbox(bboxes, pred_bboxes, self.means,
                                             self.stds, max_shape,
                                             wh_ratio_clip, self.clip_border,
                                             self.add_ctr_clamp,
                                             self.ctr_clamp)

        return decoded_bboxes
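For context, a minimal usage sketch of the coder defined above, assuming the usual mmdet 2.x import path; the proposal and ground-truth values are hypothetical:

import torch
from mmdet.core.bbox.coder import DeltaXYWHBBoxCoder

coder = DeltaXYWHBBoxCoder(target_means=(0., 0., 0., 0.),
                           target_stds=(0.1, 0.1, 0.2, 0.2))
proposals = torch.tensor([[0., 0., 10., 10.]])   # hypothetical proposal
gt = torch.tensor([[1., 1., 11., 11.]])          # hypothetical ground truth

deltas = coder.encode(proposals, gt)    # tensor([[1., 1., 0., 0.]])
decoded = coder.decode(proposals, deltas)  # recovers gt up to fp error

The non-trivial `target_stds` shown here only rescale the regression targets; decoding applies the inverse transform, so the round trip is exact apart from floating-point error.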
@mmcv.jit(coderize=True)
def bbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)):
    """Compute deltas of proposals w.r.t. gt.

    We usually compute the deltas of x, y, w, h of proposals w.r.t. ground
    truth bboxes to get regression targets.
    This is the inverse function of :func:`delta2bbox`.

    Args:
        proposals (Tensor): Boxes to be transformed, shape (N, ..., 4)
        gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4)
        means (Sequence[float]): Denormalizing means for delta coordinates
        stds (Sequence[float]): Denormalizing standard deviation for delta
            coordinates

    Returns:
        Tensor: deltas with shape (N, 4), where columns represent dx, dy,
            dw, dh.
    """
    assert proposals.size() == gt.size()

    proposals = proposals.float()
    gt = gt.float()
    px = (proposals[..., 0] + proposals[..., 2]) * 0.5
    py = (proposals[..., 1] + proposals[..., 3]) * 0.5
    pw = proposals[..., 2] - proposals[..., 0]
    ph = proposals[..., 3] - proposals[..., 1]

    gx = (gt[..., 0] + gt[..., 2]) * 0.5
    gy = (gt[..., 1] + gt[..., 3]) * 0.5
    gw = gt[..., 2] - gt[..., 0]
    gh = gt[..., 3] - gt[..., 1]

    dx = (gx - px) / pw
    dy = (gy - py) / ph
    dw = torch.log(gw / pw)
    dh = torch.log(gh / ph)
    deltas = torch.stack([dx, dy, dw, dh], dim=-1)

    means = deltas.new_tensor(means).unsqueeze(0)
    stds = deltas.new_tensor(stds).unsqueeze(0)
    deltas = deltas.sub_(means).div_(stds)

    return deltas
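The function above implements the standard R-CNN parameterization: with proposal center/size (px, py, pw, ph) and ground-truth (gx, gy, gw, gh), the targets are dx = (gx - px) / pw, dy = (gy - py) / ph, dw = log(gw / pw), dh = log(gh / ph), followed by the means/stds normalization. A small worked example, reusing `bbox2delta` from above with hypothetical boxes and the default normalization:

import torch

proposal = torch.tensor([[0., 0., 4., 4.]])   # center (2, 2), size 4 x 4
gt = torch.tensor([[1., 1., 9., 9.]])         # center (5, 5), size 8 x 8

# dx = (5 - 2) / 4 = 0.75, dy = 0.75, dw = log(8 / 4), dh = log(8 / 4)
deltas = bbox2delta(proposal, gt)
print(deltas)  # tensor([[0.7500, 0.7500, 0.6931, 0.6931]])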
@mmcv.jit(coderize=True)
def delta2bbox(rois,
               deltas,
               means=(0., 0., 0., 0.),
               stds=(1., 1., 1., 1.),
               max_shape=None,
               wh_ratio_clip=16 / 1000,
               clip_border=True,
               add_ctr_clamp=False,
               ctr_clamp=32):
    """Apply deltas to shift/scale base boxes.

    Typically the rois are anchor or proposed bounding boxes and the deltas
    are network outputs used to shift/scale those boxes.
    This is the inverse function of :func:`bbox2delta`.

    Args:
        rois (Tensor): Boxes to be transformed. Has shape (N, 4).
        deltas (Tensor): Encoded offsets relative to each roi.
            Has shape (N, num_classes * 4) or (N, 4). Note
            N = num_base_anchors * W * H, when rois is a grid of
            anchors. Offset encoding follows [1]_.
        means (Sequence[float]): Denormalizing means for delta coordinates.
            Default (0., 0., 0., 0.).
        stds (Sequence[float]): Denormalizing standard deviation for delta
            coordinates. Default (1., 1., 1., 1.).
        max_shape (tuple[int, int]): Maximum bounds for boxes, specifies
            (H, W). Default None.
        wh_ratio_clip (float): Maximum aspect ratio for boxes. Default
            16 / 1000.
        clip_border (bool, optional): Whether to clip boxes that extend
            beyond the border of the image. Default True.
        add_ctr_clamp (bool): Whether to add center clamping. When set to
            True, the center of the predicted bounding box will be clamped
            to avoid being too far away from the center of the anchor.
            Only used by YOLOF. Default False.
        ctr_clamp (int): The maximum pixel shift to clamp. Only used by
            YOLOF. Default 32.

    Returns:
        Tensor: Boxes with shape (N, num_classes * 4) or (N, 4), where 4
           represent tl_x, tl_y, br_x, br_y.

    References:
        .. [1] https://arxiv.org/abs/1311.2524

    Example:
        >>> rois = torch.Tensor([[ 0.,  0.,  1.,  1.],
        >>>                      [ 0.,  0.,  1.,  1.],
        >>>                      [ 0.,  0.,  1.,  1.],
        >>>                      [ 5.,  5.,  5.,  5.]])
        >>> deltas = torch.Tensor([[  0.,   0.,   0.,   0.],
        >>>                        [  1.,   1.,   1.,   1.],
        >>>                        [  0.,   0.,   2.,  -1.],
        >>>                        [ 0.7, -1.9, -0.5,  0.3]])
        >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
        tensor([[0.0000, 0.0000, 1.0000, 1.0000],
                [0.1409, 0.1409, 2.8591, 2.8591],
                [0.0000, 0.3161, 4.1945, 0.6839],
                [5.0000, 5.0000, 5.0000, 5.0000]])
    """
    num_bboxes, num_classes = deltas.size(0), deltas.size(1) // 4
    if num_bboxes == 0:
        return deltas

    deltas = deltas.reshape(-1, 4)

    means = deltas.new_tensor(means).view(1, -1)
    stds = deltas.new_tensor(stds).view(1, -1)
    denorm_deltas = deltas * stds + means

    dxy = denorm_deltas[:, :2]
    dwh = denorm_deltas[:, 2:]

    # Compute width/height of each roi
    rois_ = rois.repeat(1, num_classes).reshape(-1, 4)
    pxy = ((rois_[:, :2] + rois_[:, 2:]) * 0.5)
    pwh = (rois_[:, 2:] - rois_[:, :2])

    dxy_wh = pwh * dxy

    max_ratio = np.abs(np.log(wh_ratio_clip))
    if add_ctr_clamp:
        dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp)
        dwh = torch.clamp(dwh, max=max_ratio)
    else:
        dwh = dwh.clamp(min=-max_ratio, max=max_ratio)

    gxy = pxy + dxy_wh
    gwh = pwh * dwh.exp()
    x1y1 = gxy - (gwh * 0.5)
    x2y2 = gxy + (gwh * 0.5)
    bboxes = torch.cat([x1y1, x2y2], dim=-1)
    if clip_border and max_shape is not None:
        bboxes[..., 0::2].clamp_(min=0, max=max_shape[1])
        bboxes[..., 1::2].clamp_(min=0, max=max_shape[0])
    bboxes = bboxes.reshape(num_bboxes, -1)
    return bboxes
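The `add_ctr_clamp` branch bounds how far a decoded center can drift from its anchor, which matters for YOLOF's single-level design. A quick sketch of the effect, reusing `delta2bbox` from above with hypothetical values:

import torch

roi = torch.tensor([[0., 0., 100., 100.]])     # anchor, center (50, 50)
big_shift = torch.tensor([[2., 0., 0., 0.]])   # dx = 2 -> raw shift of 200 px

loose = delta2bbox(roi, big_shift)                    # center moves to x = 250
clamped = delta2bbox(roi, big_shift,
                     add_ctr_clamp=True, ctr_clamp=32)  # shift capped at 32 px
print(loose[0, 0].item(), clamped[0, 0].item())  # x1: 200.0 vs 32.0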
def onnx_delta2bbox(rois,
                    deltas,
                    means=(0., 0., 0., 0.),
                    stds=(1., 1., 1., 1.),
                    max_shape=None,
                    wh_ratio_clip=16 / 1000,
                    clip_border=True,
                    add_ctr_clamp=False,
                    ctr_clamp=32):
    """Apply deltas to shift/scale base boxes.

    Typically the rois are anchor or proposed bounding boxes and the deltas
    are network outputs used to shift/scale those boxes.
    This is the inverse function of :func:`bbox2delta`.

    Args:
        rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4)
        deltas (Tensor): Encoded offsets with respect to each roi.
            Has shape (B, N, num_classes * 4) or (B, N, 4) or
            (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
            when rois is a grid of anchors. Offset encoding follows [1]_.
        means (Sequence[float]): Denormalizing means for delta coordinates.
            Default (0., 0., 0., 0.).
        stds (Sequence[float]): Denormalizing standard deviation for delta
            coordinates. Default (1., 1., 1., 1.).
        max_shape (Sequence[int] or torch.Tensor or Sequence[
            Sequence[int]], optional): Maximum bounds for boxes, specifies
            (H, W, C) or (H, W). If rois shape is (B, N, 4), then
            the max_shape should be a Sequence[Sequence[int]]
            and the length of max_shape should also be B. Default None.
        wh_ratio_clip (float): Maximum aspect ratio for boxes.
            Default 16 / 1000.
        clip_border (bool, optional): Whether to clip boxes that extend
            beyond the border of the image. Default True.
        add_ctr_clamp (bool): Whether to add center clamping: when enabled,
            the predicted box is clamped if its center is too far away from
            the original anchor's center. Only used by YOLOF. Default False.
        ctr_clamp (int): The maximum pixel shift to clamp. Only used by
            YOLOF. Default 32.

    Returns:
        Tensor: Boxes with shape (B, N, num_classes * 4) or (B, N, 4) or
           (N, num_classes * 4) or (N, 4), where 4 represent
           tl_x, tl_y, br_x, br_y.

    References:
        .. [1] https://arxiv.org/abs/1311.2524

    Example:
        >>> rois = torch.Tensor([[ 0.,  0.,  1.,  1.],
        >>>                      [ 0.,  0.,  1.,  1.],
        >>>                      [ 0.,  0.,  1.,  1.],
        >>>                      [ 5.,  5.,  5.,  5.]])
        >>> deltas = torch.Tensor([[  0.,   0.,   0.,   0.],
        >>>                        [  1.,   1.,   1.,   1.],
        >>>                        [  0.,   0.,   2.,  -1.],
        >>>                        [ 0.7, -1.9, -0.5,  0.3]])
        >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
        tensor([[0.0000, 0.0000, 1.0000, 1.0000],
                [0.1409, 0.1409, 2.8591, 2.8591],
                [0.0000, 0.3161, 4.1945, 0.6839],
                [5.0000, 5.0000, 5.0000, 5.0000]])
    """
    means = deltas.new_tensor(means).view(1,
                                          -1).repeat(1,
                                                     deltas.size(-1) // 4)
    stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(-1) // 4)
    denorm_deltas = deltas * stds + means
    dx = denorm_deltas[..., 0::4]
    dy = denorm_deltas[..., 1::4]
    dw = denorm_deltas[..., 2::4]
    dh = denorm_deltas[..., 3::4]

    x1, y1 = rois[..., 0], rois[..., 1]
    x2, y2 = rois[..., 2], rois[..., 3]
    # Compute center of each roi
    px = ((x1 + x2) * 0.5).unsqueeze(-1).expand_as(dx)
    py = ((y1 + y2) * 0.5).unsqueeze(-1).expand_as(dy)
    # Compute width/height of each roi
    pw = (x2 - x1).unsqueeze(-1).expand_as(dw)
    ph = (y2 - y1).unsqueeze(-1).expand_as(dh)

    dx_width = pw * dx
    dy_height = ph * dy

    max_ratio = np.abs(np.log(wh_ratio_clip))
    if add_ctr_clamp:
        dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp)
        dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp)
        dw = torch.clamp(dw, max=max_ratio)
        dh = torch.clamp(dh, max=max_ratio)
    else:
        dw = dw.clamp(min=-max_ratio, max=max_ratio)
        dh = dh.clamp(min=-max_ratio, max=max_ratio)
    # Use exp(network energy) to enlarge/shrink each roi
    gw = pw * dw.exp()
    gh = ph * dh.exp()
    # Use network energy to shift the center of each roi
    gx = px + dx_width
    gy = py + dy_height
    # Convert center-xy/width/height to top-left, bottom-right
    x1 = gx - gw * 0.5
    y1 = gy - gh * 0.5
    x2 = gx + gw * 0.5
    y2 = gy + gh * 0.5

    bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())

    if clip_border and max_shape is not None:
        # clip bboxes with dynamic `min` and `max` for onnx
        if torch.onnx.is_in_onnx_export():
            from mmdet.core.export import dynamic_clip_for_onnx
            x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape)
            bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
            return bboxes
        if not isinstance(max_shape, torch.Tensor):
            max_shape = x1.new_tensor(max_shape)
        max_shape = max_shape[..., :2].type_as(x1)
        if max_shape.ndim == 2:
            assert bboxes.ndim == 3
            assert max_shape.size(0) == bboxes.size(0)

        min_xy = x1.new_tensor(0)
        max_xy = torch.cat(
            [max_shape] * (deltas.size(-1) // 2),
            dim=-1).flip(-1).unsqueeze(-2)
        bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
        bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)

    return bboxes
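Unlike `delta2bbox`, the batch-aware variant above accepts (B, N, 4) inputs and a per-image `max_shape`. A minimal sketch with hypothetical shapes, reusing `onnx_delta2bbox` from above:

import torch

rois = torch.zeros(2, 3, 4)       # batch of 2 images, 3 anchors each
rois[..., 2:] = 8.                # every anchor is an 8 x 8 box at the origin
deltas = torch.zeros(2, 3, 4)     # identity deltas keep the anchors unchanged

# One (H, W) bound per image; boxes are clipped per batch element
decoded = onnx_delta2bbox(rois, deltas, max_shape=[(6, 6), (10, 10)])
print(decoded[0].max().item(), decoded[1].max().item())  # 6.0 and 8.0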
    	
mmdet/core/bbox/coder/distance_point_bbox_coder.py
ADDED
@@ -0,0 +1,63 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import BBOX_CODERS
from ..transforms import bbox2distance, distance2bbox
from .base_bbox_coder import BaseBBoxCoder


@BBOX_CODERS.register_module()
class DistancePointBBoxCoder(BaseBBoxCoder):
    """Distance Point BBox coder.

    This coder encodes gt bboxes (x1, y1, x2, y2) into the distances from a
    point to the four box sides (left, top, right, bottom) and decodes them
    back to the original bboxes.

    Args:
        clip_border (bool, optional): Whether to clip boxes that extend
            beyond the border of the image. Defaults to True.
    """

    def __init__(self, clip_border=True):
        super(BaseBBoxCoder, self).__init__()
        self.clip_border = clip_border

    def encode(self, points, gt_bboxes, max_dis=None, eps=0.1):
        """Encode bounding boxes to distances.

        Args:
            points (Tensor): Shape (N, 2), the format is [x, y].
            gt_bboxes (Tensor): Shape (N, 4), the format is "xyxy".
            max_dis (float): Upper bound of the distance. Default None.
            eps (float): A small value to ensure target < max_dis,
                instead of <=. Default 0.1.

        Returns:
            Tensor: Box transformation deltas. The shape is (N, 4).
        """
        assert points.size(0) == gt_bboxes.size(0)
        assert points.size(-1) == 2
        assert gt_bboxes.size(-1) == 4
        return bbox2distance(points, gt_bboxes, max_dis, eps)

    def decode(self, points, pred_bboxes, max_shape=None):
        """Decode distance predictions to bounding boxes.

        Args:
            points (Tensor): Shape (B, N, 2) or (N, 2).
            pred_bboxes (Tensor): Distance from the given point to 4
                boundaries (left, top, right, bottom). Shape (B, N, 4)
                or (N, 4)
            max_shape (Sequence[int] or torch.Tensor or Sequence[
                Sequence[int]], optional): Maximum bounds for boxes,
                specifies (H, W, C) or (H, W). If priors shape is (B, N, 4),
                then the max_shape should be a Sequence[Sequence[int]],
                and the length of max_shape should also be B.
                Default None.

        Returns:
            Tensor: Boxes with shape (N, 4) or (B, N, 4)
        """
        assert points.size(0) == pred_bboxes.size(0)
        assert points.size(-1) == 2
        assert pred_bboxes.size(-1) == 4
        if self.clip_border is False:
            max_shape = None
        return distance2bbox(points, pred_bboxes, max_shape)
     | 
    	
        mmdet/core/bbox/coder/legacy_delta_xywh_bbox_coder.py
    ADDED
    
    | 
         @@ -0,0 +1,216 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import torch

from ..builder import BBOX_CODERS
from .base_bbox_coder import BaseBBoxCoder


@BBOX_CODERS.register_module()
class LegacyDeltaXYWHBBoxCoder(BaseBBoxCoder):
    """Legacy Delta XYWH BBox coder used in MMDet V1.x.

    Following the practice in R-CNN [1]_, this coder encodes bbox (x1, y1, x2,
    y2) into delta (dx, dy, dw, dh) and decodes delta (dx, dy, dw, dh)
    back to original bbox (x1, y1, x2, y2).

    Note:
        The main difference between :class:`LegacyDeltaXYWHBBoxCoder` and
        :class:`DeltaXYWHBBoxCoder` is whether ``+ 1`` is used during width and
        height calculation. We suggest using this coder only when testing with
        MMDet V1.x models.

    References:
        .. [1] https://arxiv.org/abs/1311.2524

    Args:
        target_means (Sequence[float]): Denormalizing means of target for
            delta coordinates.
        target_stds (Sequence[float]): Denormalizing standard deviation of
            target for delta coordinates.
    """

    def __init__(self,
                 target_means=(0., 0., 0., 0.),
                 target_stds=(1., 1., 1., 1.)):
        super(BaseBBoxCoder, self).__init__()
        self.means = target_means
        self.stds = target_stds

    def encode(self, bboxes, gt_bboxes):
        """Get box regression transformation deltas that can be used to
        transform the ``bboxes`` into the ``gt_bboxes``.

        Args:
            bboxes (torch.Tensor): Source boxes, e.g., object proposals.
            gt_bboxes (torch.Tensor): Target of the transformation, e.g.,
                ground-truth boxes.

        Returns:
            torch.Tensor: Box transformation deltas
        """
        assert bboxes.size(0) == gt_bboxes.size(0)
        assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
        encoded_bboxes = legacy_bbox2delta(bboxes, gt_bboxes, self.means,
                                           self.stds)
        return encoded_bboxes

    def decode(self,
               bboxes,
               pred_bboxes,
               max_shape=None,
               wh_ratio_clip=16 / 1000):
        """Apply transformation `pred_bboxes` to `bboxes`.

        Args:
            bboxes (torch.Tensor): Basic boxes.
            pred_bboxes (torch.Tensor): Encoded boxes with shape
                (N, 4 * num_classes).
            max_shape (tuple[int], optional): Maximum shape of boxes.
                Defaults to None.
            wh_ratio_clip (float, optional): The allowed ratio between
                width and height.

        Returns:
            torch.Tensor: Decoded boxes.
        """
        assert pred_bboxes.size(0) == bboxes.size(0)
        decoded_bboxes = legacy_delta2bbox(bboxes, pred_bboxes, self.means,
                                           self.stds, max_shape, wh_ratio_clip)

        return decoded_bboxes


@mmcv.jit(coderize=True)
def legacy_bbox2delta(proposals,
                      gt,
                      means=(0., 0., 0., 0.),
                      stds=(1., 1., 1., 1.)):
    """Compute deltas of proposals w.r.t. gt in the MMDet V1.x manner.

    We usually compute the deltas of x, y, w, h of proposals w.r.t ground
    truth bboxes to get regression targets.
    This is the inverse function of `legacy_delta2bbox()`.

    Args:
        proposals (Tensor): Boxes to be transformed, shape (N, ..., 4)
        gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4)
        means (Sequence[float]): Denormalizing means for delta coordinates
        stds (Sequence[float]): Denormalizing standard deviation for delta
            coordinates

    Returns:
        Tensor: deltas with shape (N, 4), where columns represent dx, dy,
            dw, dh.
    """
    assert proposals.size() == gt.size()

    proposals = proposals.float()
    gt = gt.float()
    px = (proposals[..., 0] + proposals[..., 2]) * 0.5
    py = (proposals[..., 1] + proposals[..., 3]) * 0.5
    pw = proposals[..., 2] - proposals[..., 0] + 1.0
    ph = proposals[..., 3] - proposals[..., 1] + 1.0

    gx = (gt[..., 0] + gt[..., 2]) * 0.5
    gy = (gt[..., 1] + gt[..., 3]) * 0.5
    gw = gt[..., 2] - gt[..., 0] + 1.0
    gh = gt[..., 3] - gt[..., 1] + 1.0

    dx = (gx - px) / pw
    dy = (gy - py) / ph
    dw = torch.log(gw / pw)
    dh = torch.log(gh / ph)
    deltas = torch.stack([dx, dy, dw, dh], dim=-1)

    means = deltas.new_tensor(means).unsqueeze(0)
    stds = deltas.new_tensor(stds).unsqueeze(0)
    deltas = deltas.sub_(means).div_(stds)

    return deltas


@mmcv.jit(coderize=True)
def legacy_delta2bbox(rois,
                      deltas,
                      means=(0., 0., 0., 0.),
                      stds=(1., 1., 1., 1.),
                      max_shape=None,
                      wh_ratio_clip=16 / 1000):
    """Apply deltas to shift/scale base boxes in the MMDet V1.x manner.

    Typically the rois are anchor or proposed bounding boxes and the deltas
    are network outputs used to shift/scale those boxes.
    This is the inverse function of `legacy_bbox2delta()`.

    Args:
        rois (Tensor): Boxes to be transformed. Has shape (N, 4)
        deltas (Tensor): Encoded offsets with respect to each roi.
            Has shape (N, 4 * num_classes). Note N = num_anchors * W * H when
            rois is a grid of anchors. Offset encoding follows [1]_.
        means (Sequence[float]): Denormalizing means for delta coordinates
        stds (Sequence[float]): Denormalizing standard deviation for delta
            coordinates
        max_shape (tuple[int, int]): Maximum bounds for boxes, specifies (H, W)
        wh_ratio_clip (float): Maximum aspect ratio for boxes.

    Returns:
        Tensor: Boxes with shape (N, 4), where columns represent
            tl_x, tl_y, br_x, br_y.

    References:
        .. [1] https://arxiv.org/abs/1311.2524

    Example:
        >>> rois = torch.Tensor([[ 0.,  0.,  1.,  1.],
        >>>                      [ 0.,  0.,  1.,  1.],
        >>>                      [ 0.,  0.,  1.,  1.],
        >>>                      [ 5.,  5.,  5.,  5.]])
        >>> deltas = torch.Tensor([[  0.,   0.,   0.,   0.],
        >>>                        [  1.,   1.,   1.,   1.],
        >>>                        [  0.,   0.,   2.,  -1.],
        >>>                        [ 0.7, -1.9, -0.5,  0.3]])
        >>> legacy_delta2bbox(rois, deltas, max_shape=(32, 32))
        tensor([[0.0000, 0.0000, 1.5000, 1.5000],
                [0.0000, 0.0000, 5.2183, 5.2183],
                [0.0000, 0.1321, 7.8891, 0.8679],
                [5.3967, 2.4251, 6.0033, 3.7749]])
    """
    means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4)
    stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4)
    denorm_deltas = deltas * stds + means
    dx = denorm_deltas[:, 0::4]
    dy = denorm_deltas[:, 1::4]
    dw = denorm_deltas[:, 2::4]
    dh = denorm_deltas[:, 3::4]
    max_ratio = np.abs(np.log(wh_ratio_clip))
    dw = dw.clamp(min=-max_ratio, max=max_ratio)
    dh = dh.clamp(min=-max_ratio, max=max_ratio)
    # Compute center of each roi
    px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx)
    py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy)
    # Compute width/height of each roi
    pw = (rois[:, 2] - rois[:, 0] + 1.0).unsqueeze(1).expand_as(dw)
    ph = (rois[:, 3] - rois[:, 1] + 1.0).unsqueeze(1).expand_as(dh)
    # Use exp(network energy) to enlarge/shrink each roi
    gw = pw * dw.exp()
    gh = ph * dh.exp()
    # Use network energy to shift the center of each roi
    gx = px + pw * dx
    gy = py + ph * dy
    # Convert center-xy/width/height to top-left, bottom-right

    # The true legacy box coder should +- 0.5 here.
    # However, current implementation improves the performance when testing
    # the models trained in MMDetection 1.X (~0.5 bbox AP, 0.2 mask AP)
    x1 = gx - gw * 0.5
    y1 = gy - gh * 0.5
    x2 = gx + gw * 0.5
    y2 = gy + gh * 0.5
    if max_shape is not None:
        x1 = x1.clamp(min=0, max=max_shape[1] - 1)
        y1 = y1.clamp(min=0, max=max_shape[0] - 1)
        x2 = x2.clamp(min=0, max=max_shape[1] - 1)
        y2 = y2.clamp(min=0, max=max_shape[0] - 1)
    bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas)
    return bboxes
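A minimal round-trip sketch of the legacy coder (an editor illustration, not part of the commit; it assumes the mmdet package built from these files is importable). Note that the decoded box comes back offset by 0.5 px from the ground truth, which is exactly the deliberate deviation described in the `+- 0.5` comment inside `legacy_delta2bbox`:

import torch
from mmdet.core.bbox.coder.legacy_delta_xywh_bbox_coder import (
    LegacyDeltaXYWHBBoxCoder)

coder = LegacyDeltaXYWHBBoxCoder()
proposals = torch.Tensor([[0., 0., 10., 10.]])
gts = torch.Tensor([[1., 1., 9., 9.]])
# Widths use the legacy "+ 1" convention: pw = 11, gw = 9, dw = log(9/11)
deltas = coder.encode(proposals, gts)       # tensor([[0.0000, 0.0000, -0.2007, -0.2007]])
restored = coder.decode(proposals, deltas)  # tensor([[0.5000, 0.5000, 9.5000, 9.5000]])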
    	
        mmdet/core/bbox/coder/pseudo_bbox_coder.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import BBOX_CODERS
from .base_bbox_coder import BaseBBoxCoder


@BBOX_CODERS.register_module()
class PseudoBBoxCoder(BaseBBoxCoder):
    """Pseudo bounding box coder."""

    def __init__(self, **kwargs):
        super(BaseBBoxCoder, self).__init__(**kwargs)

    def encode(self, bboxes, gt_bboxes):
        """torch.Tensor: return the given ``gt_bboxes``"""
        return gt_bboxes

    def decode(self, bboxes, pred_bboxes):
        """torch.Tensor: return the given ``pred_bboxes``"""
        return pred_bboxes
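A quick sanity check of the pass-through behaviour (illustrative only, not part of the commit). Both methods return their second argument unchanged, which suits heads that regress absolute box coordinates and therefore need no coding at all:

import torch
from mmdet.core.bbox.coder.pseudo_bbox_coder import PseudoBBoxCoder

coder = PseudoBBoxCoder()
anchors = torch.rand(2, 4)
gts = torch.rand(2, 4)
preds = torch.rand(2, 4)
assert coder.encode(anchors, gts) is gts      # gt_bboxes passed through untouched
assert coder.decode(anchors, preds) is preds  # pred_bboxes passed through untouched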
    	
        mmdet/core/bbox/coder/tblr_bbox_coder.py
ADDED
@@ -0,0 +1,206 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch

from ..builder import BBOX_CODERS
from .base_bbox_coder import BaseBBoxCoder


@BBOX_CODERS.register_module()
class TBLRBBoxCoder(BaseBBoxCoder):
    """TBLR BBox coder.

    Following the practice in `FSAF <https://arxiv.org/abs/1903.00621>`_,
    this coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left,
    right) and decodes them back to the original.

    Args:
        normalizer (list | float): Normalization factor to be
          divided with when coding the coordinates. If it is a list, it should
          have length of 4 indicating normalization factors in tblr dims.
          Otherwise it is a unified float factor for all dims. Default: 4.0.
        clip_border (bool, optional): Whether to clip the objects outside the
            border of the image. Defaults to True.
    """

    def __init__(self, normalizer=4.0, clip_border=True):
        super(BaseBBoxCoder, self).__init__()
        self.normalizer = normalizer
        self.clip_border = clip_border

    def encode(self, bboxes, gt_bboxes):
        """Get box regression transformation deltas that can be used to
        transform the ``bboxes`` into the ``gt_bboxes`` in the (top, bottom,
        left, right) order.

        Args:
            bboxes (torch.Tensor): Source boxes, e.g., object proposals.
            gt_bboxes (torch.Tensor): Target of the transformation, e.g.,
                ground truth boxes.

        Returns:
            torch.Tensor: Box transformation deltas
        """
        assert bboxes.size(0) == gt_bboxes.size(0)
        assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
        encoded_bboxes = bboxes2tblr(
            bboxes, gt_bboxes, normalizer=self.normalizer)
        return encoded_bboxes

    def decode(self, bboxes, pred_bboxes, max_shape=None):
        """Apply transformation `pred_bboxes` to `bboxes`.

        Args:
            bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4)
            pred_bboxes (torch.Tensor): Encoded boxes with shape
               (B, N, 4) or (N, 4)
            max_shape (Sequence[int] or torch.Tensor or Sequence[
               Sequence[int]], optional): Maximum bounds for boxes, specifies
               (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then
               the max_shape should be a Sequence[Sequence[int]]
               and the length of max_shape should also be B.

        Returns:
            torch.Tensor: Decoded boxes.
        """
        decoded_bboxes = tblr2bboxes(
            bboxes,
            pred_bboxes,
            normalizer=self.normalizer,
            max_shape=max_shape,
            clip_border=self.clip_border)

        return decoded_bboxes


@mmcv.jit(coderize=True)
def bboxes2tblr(priors, gts, normalizer=4.0, normalize_by_wh=True):
    """Encode ground truth boxes to tblr coordinates.

    It first converts the gt coordinates to the tblr format,
    (top, bottom, left, right), relative to prior box centers.
    The tblr coordinates may be normalized by the side lengths of the prior
    bboxes if `normalize_by_wh` is specified as True, and they are then
    normalized by the `normalizer` factor.

    Args:
        priors (Tensor): Prior boxes in point form.
            Shape: (num_proposals, 4).
        gts (Tensor): Coords of ground truth for each prior in point form.
            Shape: (num_proposals, 4).
        normalizer (Sequence[float] | float): Normalization parameter of
            encoded boxes. If it is a list, it has to have length = 4.
            Default: 4.0.
        normalize_by_wh (bool): Whether to normalize tblr coordinates by the
            side lengths (wh) of prior bboxes.

    Returns:
        encoded boxes (Tensor), Shape: (num_proposals, 4)
    """

    # dist b/t match center and prior's center
    if not isinstance(normalizer, float):
        normalizer = torch.tensor(normalizer, device=priors.device)
        assert len(normalizer) == 4, 'Normalizer must have length = 4'
    assert priors.size(0) == gts.size(0)
    prior_centers = (priors[:, 0:2] + priors[:, 2:4]) / 2
    xmin, ymin, xmax, ymax = gts.split(1, dim=1)
    top = prior_centers[:, 1].unsqueeze(1) - ymin
    bottom = ymax - prior_centers[:, 1].unsqueeze(1)
    left = prior_centers[:, 0].unsqueeze(1) - xmin
    right = xmax - prior_centers[:, 0].unsqueeze(1)
    loc = torch.cat((top, bottom, left, right), dim=1)
    if normalize_by_wh:
        # Normalize tblr by anchor width and height
        wh = priors[:, 2:4] - priors[:, 0:2]
        w, h = torch.split(wh, 1, dim=1)
        loc[:, :2] /= h  # tb is normalized by h
        loc[:, 2:] /= w  # lr is normalized by w
    # Normalize tblr by the given normalization factor
    return loc / normalizer


@mmcv.jit(coderize=True)
def tblr2bboxes(priors,
                tblr,
                normalizer=4.0,
                normalize_by_wh=True,
                max_shape=None,
                clip_border=True):
    """Decode tblr outputs to prediction boxes.

    The process includes 3 steps: 1) De-normalize tblr coordinates by
    multiplying them with `normalizer`; 2) De-normalize tblr coordinates by
    the prior bbox width and height if `normalize_by_wh` is `True`; 3) Convert
    the tblr (top, bottom, left, right) pair relative to the center of priors
    back to (xmin, ymin, xmax, ymax) coordinates.

    Args:
        priors (Tensor): Prior boxes in point form (x0, y0, x1, y1).
          Shape: (N, 4) or (B, N, 4).
        tblr (Tensor): Coords of network output in tblr form.
          Shape: (N, 4) or (B, N, 4).
        normalizer (Sequence[float] | float): Normalization parameter of
          encoded boxes. As a list, it represents the normalization factors at
          tblr dims. As a float, it is the unified normalization factor at all
          dims. Default: 4.0.
        normalize_by_wh (bool): Whether the tblr coordinates have been
          normalized by the side lengths (wh) of prior bboxes.
        max_shape (Sequence[int] or torch.Tensor or Sequence[
            Sequence[int]], optional): Maximum bounds for boxes, specifies
            (H, W, C) or (H, W). If priors shape is (B, N, 4), then
            the max_shape should be a Sequence[Sequence[int]]
            and the length of max_shape should also be B.
        clip_border (bool, optional): Whether to clip the objects outside the
            border of the image. Defaults to True.

    Returns:
        decoded boxes (Tensor): Boxes with shape (N, 4) or (B, N, 4)
    """
    if not isinstance(normalizer, float):
        normalizer = torch.tensor(normalizer, device=priors.device)
        assert len(normalizer) == 4, 'Normalizer must have length = 4'
    assert priors.size(0) == tblr.size(0)
    if priors.ndim == 3:
        assert priors.size(1) == tblr.size(1)

    loc_decode = tblr * normalizer
    prior_centers = (priors[..., 0:2] + priors[..., 2:4]) / 2
    if normalize_by_wh:
        wh = priors[..., 2:4] - priors[..., 0:2]
        w, h = torch.split(wh, 1, dim=-1)
        # In-place operations on slices would fail when exporting to ONNX
        th = h * loc_decode[..., :2]  # tb
        tw = w * loc_decode[..., 2:]  # lr
        loc_decode = torch.cat([th, tw], dim=-1)
    # loc_decode.split(1, dim=-1) cannot be exported to ONNX
    top, bottom, left, right = loc_decode.split((1, 1, 1, 1), dim=-1)
    xmin = prior_centers[..., 0].unsqueeze(-1) - left
    xmax = prior_centers[..., 0].unsqueeze(-1) + right
    ymin = prior_centers[..., 1].unsqueeze(-1) - top
    ymax = prior_centers[..., 1].unsqueeze(-1) + bottom

    bboxes = torch.cat((xmin, ymin, xmax, ymax), dim=-1)

    if clip_border and max_shape is not None:
        # clip bboxes with dynamic `min` and `max` for onnx
        if torch.onnx.is_in_onnx_export():
            from mmdet.core.export import dynamic_clip_for_onnx
            xmin, ymin, xmax, ymax = dynamic_clip_for_onnx(
                xmin, ymin, xmax, ymax, max_shape)
            bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1)
            return bboxes
        if not isinstance(max_shape, torch.Tensor):
            max_shape = priors.new_tensor(max_shape)
        max_shape = max_shape[..., :2].type_as(priors)
        if max_shape.ndim == 2:
            assert bboxes.ndim == 3
            assert max_shape.size(0) == bboxes.size(0)

        min_xy = priors.new_tensor(0)
        max_xy = torch.cat([max_shape, max_shape],
                           dim=-1).flip(-1).unsqueeze(-2)
        bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
        bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)

    return bboxes
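A round-trip sketch of the TBLR coder (illustrative, not part of the commit; the values follow directly from the formulas above). With a single prior of width 8 and height 4, the tb offsets are normalized by h and the lr offsets by w before the division by `normalizer`:

import torch
from mmdet.core.bbox.coder.tblr_bbox_coder import TBLRBBoxCoder

coder = TBLRBBoxCoder(normalizer=4.0)
priors = torch.Tensor([[0., 0., 8., 4.]])
gts = torch.Tensor([[1., 1., 7., 3.]])
# tblr = (1, 1, 3, 3) -> (1/4, 1/4, 3/8, 3/8) -> divided by 4.0
targets = coder.encode(priors, gts)       # tensor([[0.0625, 0.0625, 0.0938, 0.0938]])
restored = coder.decode(priors, targets)
assert torch.allclose(restored, gts)      # decoding inverts encoding exactly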
    	
        mmdet/core/bbox/coder/yolo_bbox_coder.py
ADDED
@@ -0,0 +1,83 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch

from ..builder import BBOX_CODERS
from .base_bbox_coder import BaseBBoxCoder


@BBOX_CODERS.register_module()
class YOLOBBoxCoder(BaseBBoxCoder):
    """YOLO BBox coder.

    Following `YOLO <https://arxiv.org/abs/1506.02640>`_, this coder divides
    the image into grids, and encodes bbox (x1, y1, x2, y2) into (cx, cy, dw,
    dh). cx, cy in [0., 1.] denote the relative center position w.r.t. the
    center of bboxes. dw, dh are the same as :obj:`DeltaXYWHBBoxCoder`.

    Args:
        eps (float): Min value of cx, cy when encoding.
    """

    def __init__(self, eps=1e-6):
        super(BaseBBoxCoder, self).__init__()
        self.eps = eps

    @mmcv.jit(coderize=True)
    def encode(self, bboxes, gt_bboxes, stride):
        """Get box regression transformation deltas that can be used to
        transform the ``bboxes`` into the ``gt_bboxes``.

        Args:
            bboxes (torch.Tensor): Source boxes, e.g., anchors.
            gt_bboxes (torch.Tensor): Target of the transformation, e.g.,
                ground-truth boxes.
            stride (torch.Tensor | int): Stride of bboxes.

        Returns:
            torch.Tensor: Box transformation deltas
        """

        assert bboxes.size(0) == gt_bboxes.size(0)
        assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
        x_center_gt = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) * 0.5
        y_center_gt = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) * 0.5
        w_gt = gt_bboxes[..., 2] - gt_bboxes[..., 0]
        h_gt = gt_bboxes[..., 3] - gt_bboxes[..., 1]
        x_center = (bboxes[..., 0] + bboxes[..., 2]) * 0.5
        y_center = (bboxes[..., 1] + bboxes[..., 3]) * 0.5
        w = bboxes[..., 2] - bboxes[..., 0]
        h = bboxes[..., 3] - bboxes[..., 1]
        w_target = torch.log((w_gt / w).clamp(min=self.eps))
        h_target = torch.log((h_gt / h).clamp(min=self.eps))
        x_center_target = ((x_center_gt - x_center) / stride + 0.5).clamp(
            self.eps, 1 - self.eps)
        y_center_target = ((y_center_gt - y_center) / stride + 0.5).clamp(
            self.eps, 1 - self.eps)
        encoded_bboxes = torch.stack(
            [x_center_target, y_center_target, w_target, h_target], dim=-1)
        return encoded_bboxes

    @mmcv.jit(coderize=True)
    def decode(self, bboxes, pred_bboxes, stride):
        """Apply transformation `pred_bboxes` to `bboxes`.

        Args:
            bboxes (torch.Tensor): Basic boxes, e.g. anchors.
            pred_bboxes (torch.Tensor): Encoded boxes with shape (..., 4).
            stride (torch.Tensor | int): Strides of bboxes.

        Returns:
            torch.Tensor: Decoded boxes.
        """
        assert pred_bboxes.size(-1) == bboxes.size(-1) == 4
        xy_centers = (bboxes[..., :2] + bboxes[..., 2:]) * 0.5 + (
            pred_bboxes[..., :2] - 0.5) * stride
        whs = (bboxes[..., 2:] -
               bboxes[..., :2]) * 0.5 * pred_bboxes[..., 2:].exp()
        decoded_bboxes = torch.stack(
            (xy_centers[..., 0] - whs[..., 0], xy_centers[..., 1] -
             whs[..., 1], xy_centers[..., 0] + whs[..., 0],
             xy_centers[..., 1] + whs[..., 1]),
            dim=-1)
        return decoded_bboxes
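A round-trip sketch of the YOLO coder (illustrative, not part of the commit). A ground truth centered on its anchor encodes to cx = cy = 0.5, and decoding inverts the transform:

import torch
from mmdet.core.bbox.coder.yolo_bbox_coder import YOLOBBoxCoder

coder = YOLOBBoxCoder()
anchors = torch.Tensor([[0., 0., 32., 32.]])
gts = torch.Tensor([[8., 8., 24., 24.]])
stride = 32
# dw = dh = log(16 / 32) = -0.6931; centers coincide, so cx = cy = 0.5
targets = coder.encode(anchors, gts, stride)  # tensor([[0.5000, 0.5000, -0.6931, -0.6931]])
restored = coder.decode(anchors, targets, stride)
assert torch.allclose(restored, gts)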
    	
        mmdet/core/bbox/demodata.py
ADDED
@@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch

from mmdet.utils.util_random import ensure_rng


def random_boxes(num=1, scale=1, rng=None):
    """Simple version of ``kwimage.Boxes.random``

    Returns:
        Tensor: shape (n, 4) in x1, y1, x2, y2 format.

    References:
        https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390

    Example:
        >>> num = 3
        >>> scale = 512
        >>> rng = 0
        >>> boxes = random_boxes(num, scale, rng)
        >>> print(boxes)
        tensor([[280.9925, 278.9802, 308.6148, 366.1769],
                [216.9113, 330.6978, 224.0446, 456.5878],
                [405.3632, 196.3221, 493.3953, 270.7942]])
    """
    rng = ensure_rng(rng)

    tlbr = rng.rand(num, 4).astype(np.float32)

    tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2])
    tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3])
    br_x = np.maximum(tlbr[:, 0], tlbr[:, 2])
    br_y = np.maximum(tlbr[:, 1], tlbr[:, 3])

    tlbr[:, 0] = tl_x * scale
    tlbr[:, 1] = tl_y * scale
    tlbr[:, 2] = br_x * scale
    tlbr[:, 3] = br_y * scale

    boxes = torch.from_numpy(tlbr)
    return boxes
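Because ``random_boxes`` routes its ``rng`` argument through ``ensure_rng``, an integer seed yields reproducible boxes, which makes this helper convenient for unit tests. A small sketch of that assumed usage (not part of the diff):

import torch
from mmdet.core.bbox.demodata import random_boxes

# The same integer seed produces identical boxes on each call,
# since ensure_rng builds a fresh RandomState from an int seed.
a = random_boxes(num=4, scale=256, rng=7)
b = random_boxes(num=4, scale=256, rng=7)
assert torch.equal(a, b)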