Spaces:
Runtime error
Runtime error
initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- .gitignore +131 -0
- README.md +67 -13
- app.py +165 -0
- deep_sort/configs/deep_sort.yaml +10 -0
- deep_sort/deep_sort/README.md +3 -0
- deep_sort/deep_sort/__init__.py +21 -0
- deep_sort/deep_sort/__pycache__/__init__.cpython-310.pyc +0 -0
- deep_sort/deep_sort/__pycache__/deep_sort.cpython-310.pyc +0 -0
- deep_sort/deep_sort/deep/__init__.py +0 -0
- deep_sort/deep_sort/deep/__pycache__/__init__.cpython-310.pyc +0 -0
- deep_sort/deep_sort/deep/__pycache__/feature_extractor.cpython-310.pyc +0 -0
- deep_sort/deep_sort/deep/__pycache__/model.cpython-310.pyc +0 -0
- deep_sort/deep_sort/deep/checkpoint/ckpt.t7 +3 -0
- deep_sort/deep_sort/deep/evaluate.py +15 -0
- deep_sort/deep_sort/deep/feature_extractor.py +65 -0
- deep_sort/deep_sort/deep/model.py +105 -0
- deep_sort/deep_sort/deep/original_model.py +106 -0
- deep_sort/deep_sort/deep/prepare_car.py +129 -0
- deep_sort/deep_sort/deep/prepare_person.py +108 -0
- deep_sort/deep_sort/deep/test.py +77 -0
- deep_sort/deep_sort/deep/train.jpg +0 -0
- deep_sort/deep_sort/deep/train.py +192 -0
- deep_sort/deep_sort/deep_sort.py +125 -0
- deep_sort/deep_sort/sort/__init__.py +0 -0
- deep_sort/deep_sort/sort/__pycache__/__init__.cpython-310.pyc +0 -0
- deep_sort/deep_sort/sort/__pycache__/detection.cpython-310.pyc +0 -0
- deep_sort/deep_sort/sort/__pycache__/iou_matching.cpython-310.pyc +0 -0
- deep_sort/deep_sort/sort/__pycache__/kalman_filter.cpython-310.pyc +0 -0
- deep_sort/deep_sort/sort/__pycache__/linear_assignment.cpython-310.pyc +0 -0
- deep_sort/deep_sort/sort/__pycache__/nn_matching.cpython-310.pyc +0 -0
- deep_sort/deep_sort/sort/__pycache__/preprocessing.cpython-310.pyc +0 -0
- deep_sort/deep_sort/sort/__pycache__/track.cpython-310.pyc +0 -0
- deep_sort/deep_sort/sort/__pycache__/tracker.cpython-310.pyc +0 -0
- deep_sort/deep_sort/sort/detection.py +49 -0
- deep_sort/deep_sort/sort/iou_matching.py +84 -0
- deep_sort/deep_sort/sort/kalman_filter.py +286 -0
- deep_sort/deep_sort/sort/linear_assignment.py +240 -0
- deep_sort/deep_sort/sort/nn_matching.py +207 -0
- deep_sort/deep_sort/sort/preprocessing.py +73 -0
- deep_sort/deep_sort/sort/track.py +199 -0
- deep_sort/deep_sort/sort/tracker.py +168 -0
- deep_sort/utils/__init__.py +0 -0
- deep_sort/utils/asserts.py +13 -0
- deep_sort/utils/draw.py +36 -0
- deep_sort/utils/evaluation.py +103 -0
- deep_sort/utils/io.py +133 -0
- deep_sort/utils/json_logger.py +383 -0
- deep_sort/utils/log.py +17 -0
- deep_sort/utils/parser.py +38 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
deep_sort/deep_sort/deep/checkpoint/ckpt.t7 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
demo.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
test.mp4 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
130 |
+
|
131 |
+
openh264-1.8.0-win64.dll
|
README.md
CHANGED
@@ -1,13 +1,67 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<h1> yolov8-deepsort-tracking </h1>
|
3 |
+
</div>
|
4 |
+
|
5 |
+

|
6 |
+
|
7 |
+
opencv+yolov8+deepsort的行人检测与跟踪。当然,也可以识别车辆等其他类别。
|
8 |
+
|
9 |
+
- 2023/10/17更新:简化代码,删除不必要的依赖
|
10 |
+
|
11 |
+
- 2023/7/4更新:加入了一个基于Gradio的WebUI界面
|
12 |
+
|
13 |
+
## 安装
|
14 |
+
环境:Python>=3.8
|
15 |
+
|
16 |
+
本项目需要pytorch,建议手动在[pytorch官网](https://pytorch.org/get-started/locally/)根据自己的平台和CUDA环境安装对应的版本。
|
17 |
+
|
18 |
+
pytorch的详细安装教程可以参照[Conda Quickstart Guide for Ultralytics](https://docs.ultralytics.com/guides/conda-quickstart/)
|
19 |
+
|
20 |
+
安装完pytorch后,需要通过以下命令来安装其他依赖:
|
21 |
+
|
22 |
+
```shell
|
23 |
+
$ pip install -r requirements.txt
|
24 |
+
```
|
25 |
+
|
26 |
+
|
27 |
+
## 配置(非WebUI)
|
28 |
+
|
29 |
+
在main.py中修改以下代码,将输入视频路径换成你要处理的视频的路径:
|
30 |
+
|
31 |
+
```python
|
32 |
+
input_video_path = "test.mp4"
|
33 |
+
```
|
34 |
+
|
35 |
+
模型默认使用Ultralytics官方的YOLOv8n模型:
|
36 |
+
|
37 |
+
```python
|
38 |
+
model = "yolov8n.pt"
|
39 |
+
```
|
40 |
+
|
41 |
+
第一次使用会自动从官网下载模型,如果网速过慢,可以在[ultralytics的官方文档](https://docs.ultralytics.com/tasks/detect/)下载模型,然后将模型文件拷贝到程序所在目录下。
|
42 |
+
|
43 |
+
## 运行(非WebUI)
|
44 |
+
|
45 |
+
运行main.py
|
46 |
+
运行完成后,终端会显示输出视频所在的路径。
|
47 |
+
|
48 |
+
## WebUI界面的配置和运行
|
49 |
+
|
50 |
+
**请先确保已经安装完成上面的依赖**
|
51 |
+
|
52 |
+
安装Gradio库:
|
53 |
+
|
54 |
+
```shell
|
55 |
+
$ pip install gradio
|
56 |
+
```
|
57 |
+
|
58 |
+
运行app.py,如果控制台出现以下消息代表成功运行:
|
59 |
+
```shell
|
60 |
+
Running on local URL: http://127.0.0.1:6006
|
61 |
+
To create a public link, set `share=True` in `launch()`
|
62 |
+
```
|
63 |
+
|
64 |
+
浏览器打开该URL即可使用WebUI界面
|
65 |
+
|
66 |
+

|
67 |
+
|
app.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ultralytics import YOLO
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import tempfile
|
5 |
+
from pathlib import Path
|
6 |
+
import deep_sort.deep_sort.deep_sort as ds
|
7 |
+
|
8 |
+
import gradio as gr
|
9 |
+
|
10 |
+
# YoloV8官方模型,从左往右由小到大,第一次使用会自动下载
|
11 |
+
model_list = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt", "yolov8l.pt", "yolov8x.pt"]
|
12 |
+
|
13 |
+
def putTextWithBackground(
|
14 |
+
img,
|
15 |
+
text,
|
16 |
+
origin,
|
17 |
+
font=cv2.FONT_HERSHEY_SIMPLEX,
|
18 |
+
font_scale=1,
|
19 |
+
text_color=(255, 255, 255),
|
20 |
+
bg_color=(0, 0, 0),
|
21 |
+
thickness=1,
|
22 |
+
):
|
23 |
+
"""绘制带有背景的文本。
|
24 |
+
|
25 |
+
:param img: 输入图像。
|
26 |
+
:param text: 要绘制的文本。
|
27 |
+
:param origin: 文本的左上角坐标。
|
28 |
+
:param font: 字体类型。
|
29 |
+
:param font_scale: 字体大小。
|
30 |
+
:param text_color: 文本的颜色。
|
31 |
+
:param bg_color: 背景的颜色。
|
32 |
+
:param thickness: 文本的线条厚度。
|
33 |
+
"""
|
34 |
+
# 计算文本的尺寸
|
35 |
+
(text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
|
36 |
+
|
37 |
+
# 绘制背景矩形
|
38 |
+
bottom_left = origin
|
39 |
+
top_right = (origin[0] + text_width, origin[1] - text_height - 5) # 减去5以留出一些边距
|
40 |
+
cv2.rectangle(img, bottom_left, top_right, bg_color, -1)
|
41 |
+
|
42 |
+
# 在矩形上绘制文本
|
43 |
+
text_origin = (origin[0], origin[1] - 5) # 从左上角的位置减去5来留出一些边距
|
44 |
+
cv2.putText(
|
45 |
+
img,
|
46 |
+
text,
|
47 |
+
text_origin,
|
48 |
+
font,
|
49 |
+
font_scale,
|
50 |
+
text_color,
|
51 |
+
thickness,
|
52 |
+
lineType=cv2.LINE_AA,
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
# 视频处理
|
57 |
+
def processVideo(inputPath, model):
|
58 |
+
"""处理视频,检测并跟踪行人。
|
59 |
+
|
60 |
+
:param inputPath: 视频文件路径
|
61 |
+
:return: 输出视频的路径
|
62 |
+
"""
|
63 |
+
tracker = ds.DeepSort(
|
64 |
+
"deep_sort/deep_sort/deep/checkpoint/ckpt.t7"
|
65 |
+
) # 加载deepsort权重文件
|
66 |
+
model = YOLO(model) # 加载YOLO模型文件
|
67 |
+
|
68 |
+
# 读取视频文件
|
69 |
+
cap = cv2.VideoCapture(inputPath)
|
70 |
+
fps = cap.get(cv2.CAP_PROP_FPS) # 获取视频的帧率
|
71 |
+
size = (
|
72 |
+
int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
|
73 |
+
int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
|
74 |
+
) # 获取视频的大小
|
75 |
+
output_video = cv2.VideoWriter() # 初始化视频写入
|
76 |
+
outputPath = tempfile.mkdtemp() # 创建输出视频的临时文件夹的路径
|
77 |
+
|
78 |
+
# 输出格式为XVID格式的avi文件
|
79 |
+
# 如果需要使用h264编码或者需要保存为其他格式,可能需要下载openh264-1.8.0
|
80 |
+
# 下载地址:https://github.com/cisco/openh264/releases/tag/v1.8.0
|
81 |
+
# 下载完成后将dll文件放在当前文件夹内
|
82 |
+
output_type = "avi"
|
83 |
+
if output_type == "avi":
|
84 |
+
fourcc = cv2.VideoWriter_fourcc(*"XVID")
|
85 |
+
video_save_path = Path(outputPath) / "output.avi" # 创建输出视频路径
|
86 |
+
if output_type == "mp4": # 浏览器只支持播放h264编码的mp4视频文件
|
87 |
+
fourcc = cv2.VideoWriter_fourcc(*"h264")
|
88 |
+
video_save_path = Path(outputPath) / "output.mp4"
|
89 |
+
|
90 |
+
output_video.open(video_save_path.as_posix(), fourcc, fps, size, True)
|
91 |
+
# 对每一帧图片进行读取和处理
|
92 |
+
while True:
|
93 |
+
success, frame = cap.read()
|
94 |
+
if not (success):
|
95 |
+
break
|
96 |
+
|
97 |
+
# 获取每一帧的目标检测推理结果
|
98 |
+
results = model(frame, stream=True)
|
99 |
+
|
100 |
+
detections = [] # 存放bounding box结果
|
101 |
+
confarray = [] # 存放每个检测结果的置信度
|
102 |
+
|
103 |
+
# 读取目标检测推理结果
|
104 |
+
# 参考: https://docs.ultralytics.com/modes/predict/#working-with-results
|
105 |
+
for r in results:
|
106 |
+
boxes = r.boxes
|
107 |
+
for box in boxes:
|
108 |
+
x1, y1, x2, y2 = map(int, box.xywh[0]) # 提取矩形框左上和右下的点,并将tensor类型转为整型
|
109 |
+
conf = round(float(box.conf[0]), 2) # 对conf四舍五入到2位小数
|
110 |
+
cls = int(box.cls[0]) # 获取物体类别标签
|
111 |
+
|
112 |
+
if cls == detect_class:
|
113 |
+
detections.append([x1, y1, x2, y2])
|
114 |
+
confarray.append(conf)
|
115 |
+
|
116 |
+
# 使用deepsort进行跟踪
|
117 |
+
resultsTracker = tracker.update(np.array(detections), confarray, frame)
|
118 |
+
for x1, y1, x2, y2, Id in resultsTracker:
|
119 |
+
x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
|
120 |
+
|
121 |
+
# 绘制bounding box
|
122 |
+
cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 255), 3)
|
123 |
+
putTextWithBackground(
|
124 |
+
frame,
|
125 |
+
str(int(Id)),
|
126 |
+
(max(-10, x1), max(40, y1)),
|
127 |
+
font_scale=1.5,
|
128 |
+
text_color=(255, 255, 255),
|
129 |
+
bg_color=(255, 0, 255),
|
130 |
+
)
|
131 |
+
|
132 |
+
output_video.write(frame) # 将处理后的图像写入视频
|
133 |
+
output_video.release() # 释放
|
134 |
+
cap.release() # 释放
|
135 |
+
print(f"output dir is: {video_save_path.as_posix()}")
|
136 |
+
return video_save_path.as_posix(), video_save_path.as_posix() # Gradio的视频控件实际读取的是文件路径
|
137 |
+
|
138 |
+
|
139 |
+
if __name__ == "__main__":
|
140 |
+
# 需要跟踪的物体类别
|
141 |
+
detect_class = 0
|
142 |
+
|
143 |
+
# Gradio参考文档:https://www.gradio.app/guides/blocks-and-event-listeners
|
144 |
+
with gr.Blocks() as demo:
|
145 |
+
with gr.Tab("Tracking"):
|
146 |
+
gr.Markdown(
|
147 |
+
"""
|
148 |
+
# YoloV8 + deepsort
|
149 |
+
基于opencv + YoloV8 + deepsort
|
150 |
+
"""
|
151 |
+
)
|
152 |
+
with gr.Row():
|
153 |
+
with gr.Column():
|
154 |
+
input_video = gr.Video(label="Input video")
|
155 |
+
model = gr.Dropdown(model_list, value="yolov8n.pt", label="Model")
|
156 |
+
with gr.Column():
|
157 |
+
output = gr.Video()
|
158 |
+
output_path = gr.Textbox(label="Output path")
|
159 |
+
button = gr.Button("Process")
|
160 |
+
|
161 |
+
button.click(
|
162 |
+
processVideo, inputs=[input_video, model], outputs=[output, output_path]
|
163 |
+
)
|
164 |
+
|
165 |
+
demo.launch(server_port=6006)
|
deep_sort/configs/deep_sort.yaml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DEEPSORT:
|
2 |
+
REID_CKPT: "deep_sort/deep_sort/deep/checkpoint/ckpt.t7"
|
3 |
+
MAX_DIST: 0.2
|
4 |
+
MIN_CONFIDENCE: 0.3
|
5 |
+
NMS_MAX_OVERLAP: 0.5
|
6 |
+
MAX_IOU_DISTANCE: 0.7
|
7 |
+
MAX_AGE: 70
|
8 |
+
N_INIT: 3
|
9 |
+
NN_BUDGET: 100
|
10 |
+
|
deep_sort/deep_sort/README.md
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Deep Sort
|
2 |
+
|
3 |
+
This is the implemention of deep sort with pytorch.
|
deep_sort/deep_sort/__init__.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .deep_sort import DeepSort
|
2 |
+
|
3 |
+
|
4 |
+
__all__ = ['DeepSort', 'build_tracker']
|
5 |
+
|
6 |
+
|
7 |
+
def build_tracker(cfg, use_cuda):
|
8 |
+
return DeepSort(cfg.DEEPSORT.REID_CKPT,
|
9 |
+
max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE,
|
10 |
+
nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE,
|
11 |
+
max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET, use_cuda=use_cuda)
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
|
deep_sort/deep_sort/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (619 Bytes). View file
|
|
deep_sort/deep_sort/__pycache__/deep_sort.cpython-310.pyc
ADDED
Binary file (4.16 kB). View file
|
|
deep_sort/deep_sort/deep/__init__.py
ADDED
File without changes
|
deep_sort/deep_sort/deep/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (168 Bytes). View file
|
|
deep_sort/deep_sort/deep/__pycache__/feature_extractor.cpython-310.pyc
ADDED
Binary file (2.58 kB). View file
|
|
deep_sort/deep_sort/deep/__pycache__/model.cpython-310.pyc
ADDED
Binary file (2.82 kB). View file
|
|
deep_sort/deep_sort/deep/checkpoint/ckpt.t7
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:22628596f112dc7eb1fe7adfbfaf95bbc6ce8eb024205beafdc705232a646c29
|
3 |
+
size 46061055
|
deep_sort/deep_sort/deep/evaluate.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
features = torch.load("features.pth")
|
4 |
+
qf = features["qf"]
|
5 |
+
ql = features["ql"]
|
6 |
+
gf = features["gf"]
|
7 |
+
gl = features["gl"]
|
8 |
+
|
9 |
+
scores = qf.mm(gf.t())
|
10 |
+
res = scores.topk(5, dim=1)[1][:,0]
|
11 |
+
top1correct = gl[res].eq(ql).sum().item()
|
12 |
+
|
13 |
+
print("Acc top1:{:.3f}".format(top1correct/ql.size(0)))
|
14 |
+
|
15 |
+
|
deep_sort/deep_sort/deep/feature_extractor.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision.transforms as transforms
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
import logging
|
6 |
+
|
7 |
+
from .model import Net
|
8 |
+
|
9 |
+
'''
|
10 |
+
特征提取器:
|
11 |
+
提取对应bounding box中的特征, 得到一个固定维度的embedding作为该bounding box的代表,
|
12 |
+
供计算相似度时使用。
|
13 |
+
|
14 |
+
模型训练是按照传统ReID的方法进行,使用Extractor类的时候输入为一个list的图片,得到图片对应的特征。
|
15 |
+
'''
|
16 |
+
|
17 |
+
class Extractor(object):
|
18 |
+
def __init__(self, model_path, use_cuda=True):
|
19 |
+
self.net = Net(reid=True)
|
20 |
+
self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
|
21 |
+
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)['net_dict']
|
22 |
+
self.net.load_state_dict(state_dict)
|
23 |
+
logger = logging.getLogger("root.tracker")
|
24 |
+
logger.info("Loading weights from {}... Done!".format(model_path))
|
25 |
+
self.net.to(self.device)
|
26 |
+
self.size = (64, 128)
|
27 |
+
self.norm = transforms.Compose([
|
28 |
+
# RGB图片数据范围是[0-255],需要先经过ToTensor除以255归一化到[0,1]之后,
|
29 |
+
# 再通过Normalize计算(x - mean)/std后,将数据归一化到[-1,1]。
|
30 |
+
transforms.ToTensor(),
|
31 |
+
# mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]是从imagenet训练集中算出来的
|
32 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
33 |
+
])
|
34 |
+
|
35 |
+
def _preprocess(self, im_crops):
|
36 |
+
"""
|
37 |
+
TODO:
|
38 |
+
1. to float with scale from 0 to 1
|
39 |
+
2. resize to (64, 128) as Market1501 dataset did
|
40 |
+
3. concatenate to a numpy array
|
41 |
+
3. to torch Tensor
|
42 |
+
4. normalize
|
43 |
+
"""
|
44 |
+
def _resize(im, size):
|
45 |
+
return cv2.resize(im.astype(np.float32)/255., size)
|
46 |
+
|
47 |
+
im_batch = torch.cat([self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops], dim=0).float()
|
48 |
+
return im_batch
|
49 |
+
|
50 |
+
# __call__()是一个非常特殊的实例方法。该方法的功能类似于在类中重载 () 运算符,
|
51 |
+
# 使得类实例对象可以像调用普通函数那样,以“对象名()”的形式使用。
|
52 |
+
def __call__(self, im_crops):
|
53 |
+
im_batch = self._preprocess(im_crops)
|
54 |
+
with torch.no_grad():
|
55 |
+
im_batch = im_batch.to(self.device)
|
56 |
+
features = self.net(im_batch)
|
57 |
+
return features.cpu().numpy()
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == '__main__':
|
61 |
+
img = cv2.imread("demo.jpg")[:,:,(2,1,0)]
|
62 |
+
extr = Extractor("checkpoint/ckpt.t7")
|
63 |
+
feature = extr(img)
|
64 |
+
print(feature.shape)
|
65 |
+
|
deep_sort/deep_sort/deep/model.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
class BasicBlock(nn.Module):
|
6 |
+
def __init__(self, c_in, c_out,is_downsample=False):
|
7 |
+
super(BasicBlock,self).__init__()
|
8 |
+
self.is_downsample = is_downsample
|
9 |
+
if is_downsample:
|
10 |
+
self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False)
|
11 |
+
else:
|
12 |
+
self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)
|
13 |
+
self.bn1 = nn.BatchNorm2d(c_out)
|
14 |
+
self.relu = nn.ReLU(True)
|
15 |
+
self.conv2 = nn.Conv2d(c_out,c_out,3,stride=1,padding=1, bias=False)
|
16 |
+
self.bn2 = nn.BatchNorm2d(c_out)
|
17 |
+
if is_downsample:
|
18 |
+
self.downsample = nn.Sequential(
|
19 |
+
nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
|
20 |
+
nn.BatchNorm2d(c_out)
|
21 |
+
)
|
22 |
+
elif c_in != c_out:
|
23 |
+
self.downsample = nn.Sequential(
|
24 |
+
nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
|
25 |
+
nn.BatchNorm2d(c_out)
|
26 |
+
)
|
27 |
+
self.is_downsample = True
|
28 |
+
|
29 |
+
def forward(self,x):
|
30 |
+
y = self.conv1(x)
|
31 |
+
y = self.bn1(y)
|
32 |
+
y = self.relu(y)
|
33 |
+
y = self.conv2(y)
|
34 |
+
y = self.bn2(y)
|
35 |
+
if self.is_downsample:
|
36 |
+
x = self.downsample(x)
|
37 |
+
return F.relu(x.add(y),True)
|
38 |
+
|
39 |
+
def make_layers(c_in,c_out,repeat_times, is_downsample=False):
|
40 |
+
blocks = []
|
41 |
+
for i in range(repeat_times):
|
42 |
+
if i ==0:
|
43 |
+
blocks += [BasicBlock(c_in,c_out, is_downsample=is_downsample),]
|
44 |
+
else:
|
45 |
+
blocks += [BasicBlock(c_out,c_out),]
|
46 |
+
return nn.Sequential(*blocks)
|
47 |
+
|
48 |
+
class Net(nn.Module):
|
49 |
+
def __init__(self, num_classes=751, reid=False):
|
50 |
+
super(Net,self).__init__()
|
51 |
+
# 3 128 64
|
52 |
+
self.conv = nn.Sequential(
|
53 |
+
nn.Conv2d(3,64,3,stride=1,padding=1),
|
54 |
+
nn.BatchNorm2d(64),
|
55 |
+
nn.ReLU(inplace=True),
|
56 |
+
# nn.Conv2d(32,32,3,stride=1,padding=1),
|
57 |
+
# nn.BatchNorm2d(32),
|
58 |
+
# nn.ReLU(inplace=True),
|
59 |
+
nn.MaxPool2d(3,2,padding=1),
|
60 |
+
)
|
61 |
+
# 32 64 32
|
62 |
+
self.layer1 = make_layers(64,64,2,False)
|
63 |
+
# 32 64 32
|
64 |
+
self.layer2 = make_layers(64,128,2,True)
|
65 |
+
# 64 32 16
|
66 |
+
self.layer3 = make_layers(128,256,2,True)
|
67 |
+
# 128 16 8
|
68 |
+
self.layer4 = make_layers(256,512,2,True)
|
69 |
+
# 256 8 4
|
70 |
+
self.avgpool = nn.AvgPool2d((8,4),1)
|
71 |
+
# 256 1 1
|
72 |
+
self.reid = reid
|
73 |
+
|
74 |
+
self.classifier = nn.Sequential(
|
75 |
+
nn.Linear(512, 256),
|
76 |
+
nn.BatchNorm1d(256),
|
77 |
+
nn.ReLU(inplace=True),
|
78 |
+
nn.Dropout(),
|
79 |
+
nn.Linear(256, num_classes),
|
80 |
+
)
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
x = self.conv(x)
|
84 |
+
x = self.layer1(x)
|
85 |
+
x = self.layer2(x)
|
86 |
+
x = self.layer3(x)
|
87 |
+
x = self.layer4(x)
|
88 |
+
x = self.avgpool(x)
|
89 |
+
x = x.view(x.size(0),-1)
|
90 |
+
# B x 128
|
91 |
+
if self.reid:
|
92 |
+
x = x.div(x.norm(p=2,dim=1,keepdim=True))
|
93 |
+
return x
|
94 |
+
# classifier
|
95 |
+
x = self.classifier(x)
|
96 |
+
return x
|
97 |
+
|
98 |
+
|
99 |
+
if __name__ == '__main__':
|
100 |
+
net = Net()
|
101 |
+
x = torch.randn(4,3,128,64)
|
102 |
+
y = net(x)
|
103 |
+
import ipdb; ipdb.set_trace()
|
104 |
+
|
105 |
+
|
deep_sort/deep_sort/deep/original_model.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
class BasicBlock(nn.Module):
|
6 |
+
def __init__(self, c_in, c_out,is_downsample=False):
|
7 |
+
super(BasicBlock,self).__init__()
|
8 |
+
self.is_downsample = is_downsample
|
9 |
+
if is_downsample:
|
10 |
+
self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False)
|
11 |
+
else:
|
12 |
+
self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)
|
13 |
+
self.bn1 = nn.BatchNorm2d(c_out)
|
14 |
+
self.relu = nn.ReLU(True)
|
15 |
+
self.conv2 = nn.Conv2d(c_out,c_out,3,stride=1,padding=1, bias=False)
|
16 |
+
self.bn2 = nn.BatchNorm2d(c_out)
|
17 |
+
if is_downsample:
|
18 |
+
self.downsample = nn.Sequential(
|
19 |
+
nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
|
20 |
+
nn.BatchNorm2d(c_out)
|
21 |
+
)
|
22 |
+
elif c_in != c_out:
|
23 |
+
self.downsample = nn.Sequential(
|
24 |
+
nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
|
25 |
+
nn.BatchNorm2d(c_out)
|
26 |
+
)
|
27 |
+
self.is_downsample = True
|
28 |
+
|
29 |
+
def forward(self,x):
|
30 |
+
y = self.conv1(x)
|
31 |
+
y = self.bn1(y)
|
32 |
+
y = self.relu(y)
|
33 |
+
y = self.conv2(y)
|
34 |
+
y = self.bn2(y)
|
35 |
+
if self.is_downsample:
|
36 |
+
x = self.downsample(x)
|
37 |
+
return F.relu(x.add(y),True)
|
38 |
+
|
39 |
+
def make_layers(c_in,c_out,repeat_times, is_downsample=False):
|
40 |
+
blocks = []
|
41 |
+
for i in range(repeat_times):
|
42 |
+
if i ==0:
|
43 |
+
blocks += [BasicBlock(c_in,c_out, is_downsample=is_downsample),]
|
44 |
+
else:
|
45 |
+
blocks += [BasicBlock(c_out,c_out),]
|
46 |
+
return nn.Sequential(*blocks)
|
47 |
+
|
48 |
+
class Net(nn.Module):
|
49 |
+
def __init__(self, num_classes=625 ,reid=False):
|
50 |
+
super(Net,self).__init__()
|
51 |
+
# 3 128 64
|
52 |
+
self.conv = nn.Sequential(
|
53 |
+
nn.Conv2d(3,32,3,stride=1,padding=1),
|
54 |
+
nn.BatchNorm2d(32),
|
55 |
+
nn.ELU(inplace=True),
|
56 |
+
nn.Conv2d(32,32,3,stride=1,padding=1),
|
57 |
+
nn.BatchNorm2d(32),
|
58 |
+
nn.ELU(inplace=True),
|
59 |
+
nn.MaxPool2d(3,2,padding=1),
|
60 |
+
)
|
61 |
+
# 32 64 32
|
62 |
+
self.layer1 = make_layers(32,32,2,False)
|
63 |
+
# 32 64 32
|
64 |
+
self.layer2 = make_layers(32,64,2,True)
|
65 |
+
# 64 32 16
|
66 |
+
self.layer3 = make_layers(64,128,2,True)
|
67 |
+
# 128 16 8
|
68 |
+
self.dense = nn.Sequential(
|
69 |
+
nn.Dropout(p=0.6),
|
70 |
+
nn.Linear(128*16*8, 128),
|
71 |
+
nn.BatchNorm1d(128),
|
72 |
+
nn.ELU(inplace=True)
|
73 |
+
)
|
74 |
+
# 256 1 1
|
75 |
+
self.reid = reid
|
76 |
+
self.batch_norm = nn.BatchNorm1d(128)
|
77 |
+
self.classifier = nn.Sequential(
|
78 |
+
nn.Linear(128, num_classes),
|
79 |
+
)
|
80 |
+
|
81 |
+
def forward(self, x):
|
82 |
+
x = self.conv(x)
|
83 |
+
x = self.layer1(x)
|
84 |
+
x = self.layer2(x)
|
85 |
+
x = self.layer3(x)
|
86 |
+
|
87 |
+
x = x.view(x.size(0),-1)
|
88 |
+
if self.reid:
|
89 |
+
x = self.dense[0](x)
|
90 |
+
x = self.dense[1](x)
|
91 |
+
x = x.div(x.norm(p=2,dim=1,keepdim=True))
|
92 |
+
return x
|
93 |
+
x = self.dense(x)
|
94 |
+
# B x 128
|
95 |
+
# classifier
|
96 |
+
x = self.classifier(x)
|
97 |
+
return x
|
98 |
+
|
99 |
+
|
100 |
+
if __name__ == '__main__':
|
101 |
+
net = Net(reid=True)
|
102 |
+
x = torch.randn(4,3,128,64)
|
103 |
+
y = net(x)
|
104 |
+
import ipdb; ipdb.set_trace()
|
105 |
+
|
106 |
+
|
deep_sort/deep_sort/deep/prepare_car.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf8 -*-
|
2 |
+
|
3 |
+
import os
|
4 |
+
from PIL import Image
|
5 |
+
from shutil import copyfile, copytree, rmtree, move
|
6 |
+
|
7 |
+
PATH_DATASET = './car-dataset' # 需要处理的文件夹
|
8 |
+
PATH_NEW_DATASET = './car-reid-dataset' # 处理后的文件夹
|
9 |
+
PATH_ALL_IMAGES = PATH_NEW_DATASET + '/all_images'
|
10 |
+
PATH_TRAIN = PATH_NEW_DATASET + '/train'
|
11 |
+
PATH_TEST = PATH_NEW_DATASET + '/test'
|
12 |
+
|
13 |
+
# 定义创建目录函数
|
14 |
+
def mymkdir(path):
|
15 |
+
path = path.strip() # 去除首位空格
|
16 |
+
path = path.rstrip("\\") # 去除尾部 \ 符号
|
17 |
+
isExists = os.path.exists(path) # 判断路径是否存在
|
18 |
+
if not isExists:
|
19 |
+
os.makedirs(path) # 如果不存在则创建目录
|
20 |
+
print(path + ' 创建成功')
|
21 |
+
return True
|
22 |
+
else:
|
23 |
+
# 如果目录存在则不创建,并提示目录已存在
|
24 |
+
print(path + ' 目录已存在')
|
25 |
+
return False
|
26 |
+
|
27 |
+
class BatchRename():
|
28 |
+
'''
|
29 |
+
批量重命名文件夹中的图片文件
|
30 |
+
'''
|
31 |
+
|
32 |
+
def __init__(self):
|
33 |
+
self.path = PATH_DATASET # 表示需要命名处理的文件夹
|
34 |
+
|
35 |
+
# 修改图像尺寸
|
36 |
+
def resize(self):
|
37 |
+
for aroot, dirs, files in os.walk(self.path):
|
38 |
+
# aroot是self.path目录下的所有子目录(含self.path),dir是self.path下所有的文件夹的列表.
|
39 |
+
filelist = files # 注意此处仅是该路径下的其中一个列表
|
40 |
+
# print('list', list)
|
41 |
+
|
42 |
+
# filelist = os.listdir(self.path) #获取文件路径
|
43 |
+
total_num = len(filelist) # 获取文件长度(个数)
|
44 |
+
|
45 |
+
for item in filelist:
|
46 |
+
if item.endswith('.jpg'): # 初始的图片的格式为jpg格式的(或者源文件是png格式及其他格式,后面的转换格式就可以调整为自己需要的格式即可)
|
47 |
+
src = os.path.join(os.path.abspath(aroot), item)
|
48 |
+
|
49 |
+
# 修改图片尺寸到128宽*256高
|
50 |
+
im = Image.open(src)
|
51 |
+
out = im.resize((128, 256), Image.ANTIALIAS) # resize image with high-quality
|
52 |
+
out.save(src) # 原路径保存
|
53 |
+
|
54 |
+
def rename(self):
|
55 |
+
|
56 |
+
for aroot, dirs, files in os.walk(self.path):
|
57 |
+
# aroot是self.path目录下的所有子目录(含self.path),dir是self.path下所有的文件夹的列表.
|
58 |
+
filelist = files # 注意此处仅是该路径下的其中一个列表
|
59 |
+
# print('list', list)
|
60 |
+
|
61 |
+
# filelist = os.listdir(self.path) #获取文件路径
|
62 |
+
total_num = len(filelist) # 获取文件长度(个数)
|
63 |
+
|
64 |
+
i = 1 # 表示文件的命名是从1开始的
|
65 |
+
for item in filelist:
|
66 |
+
if item.endswith('.jpg'): # 初始的图片的格式为jpg格式的(或者源文件是png格式及其他格式,后面的转换格式就可以调整为自己需要的格式即可)
|
67 |
+
src = os.path.join(os.path.abspath(aroot), item)
|
68 |
+
|
69 |
+
# 根据图片名创建图片目录
|
70 |
+
dirname = str(item.split('_')[0])
|
71 |
+
# 为相同车辆创建目录
|
72 |
+
#new_dir = os.path.join(self.path, '..', 'bbox_all', dirname)
|
73 |
+
new_dir = os.path.join(PATH_ALL_IMAGES, dirname)
|
74 |
+
if not os.path.isdir(new_dir):
|
75 |
+
mymkdir(new_dir)
|
76 |
+
|
77 |
+
# 获得new_dir中的图片数
|
78 |
+
num_pic = len(os.listdir(new_dir))
|
79 |
+
|
80 |
+
dst = os.path.join(os.path.abspath(new_dir),
|
81 |
+
dirname + 'C1T0001F' + str(num_pic + 1) + '.jpg')
|
82 |
+
# 处理后的格式也为jpg格式的,当然这里可以改成png格式 C1T0001F见mars.py filenames 相机ID,跟踪指数
|
83 |
+
# dst = os.path.join(os.path.abspath(self.path), '0000' + format(str(i), '0>3s') + '.jpg') 这种情况下的命名格式为0000000.jpg形式,可以自主定义想要的格式
|
84 |
+
try:
|
85 |
+
copyfile(src, dst) #os.rename(src, dst)
|
86 |
+
print ('converting %s to %s ...' % (src, dst))
|
87 |
+
i = i + 1
|
88 |
+
except:
|
89 |
+
continue
|
90 |
+
print ('total %d to rename & converted %d jpgs' % (total_num, i))
|
91 |
+
|
92 |
+
def split(self):
|
93 |
+
#---------------------------------------
|
94 |
+
#train_test
|
95 |
+
images_path = PATH_ALL_IMAGES
|
96 |
+
train_save_path = PATH_TRAIN
|
97 |
+
test_save_path = PATH_TEST
|
98 |
+
if not os.path.isdir(train_save_path):
|
99 |
+
os.mkdir(train_save_path)
|
100 |
+
os.mkdir(test_save_path)
|
101 |
+
|
102 |
+
for _, dirs, _ in os.walk(images_path, topdown=True):
|
103 |
+
for i, dir in enumerate(dirs):
|
104 |
+
for root, _, files in os.walk(images_path + '/' + dir, topdown=True):
|
105 |
+
for j, file in enumerate(files):
|
106 |
+
if(j==0): # test dataset;每个车辆的第一幅图片
|
107 |
+
print("序号:%s 文件夹: %s 图片:%s ��为测试集" % (i + 1, root, file))
|
108 |
+
src_path = root + '/' + file
|
109 |
+
dst_dir = test_save_path + '/' + dir
|
110 |
+
if not os.path.isdir(dst_dir):
|
111 |
+
os.mkdir(dst_dir)
|
112 |
+
dst_path = dst_dir + '/' + file
|
113 |
+
move(src_path, dst_path)
|
114 |
+
else:
|
115 |
+
src_path = root + '/' + file
|
116 |
+
dst_dir = train_save_path + '/' + dir
|
117 |
+
if not os.path.isdir(dst_dir):
|
118 |
+
os.mkdir(dst_dir)
|
119 |
+
dst_path = dst_dir + '/' + file
|
120 |
+
move(src_path, dst_path)
|
121 |
+
rmtree(PATH_ALL_IMAGES)
|
122 |
+
|
123 |
+
if __name__ == '__main__':
|
124 |
+
demo = BatchRename()
|
125 |
+
demo.resize()
|
126 |
+
demo.rename()
|
127 |
+
demo.split()
|
128 |
+
|
129 |
+
|
deep_sort/deep_sort/deep/prepare_person.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from shutil import copyfile
|
3 |
+
|
4 |
+
# You only need to change this line to your dataset download path
|
5 |
+
download_path = './Market-1501-v15.09.15'
|
6 |
+
|
7 |
+
if not os.path.isdir(download_path):
|
8 |
+
print('please change the download_path')
|
9 |
+
|
10 |
+
save_path = download_path + '/pytorch'
|
11 |
+
if not os.path.isdir(save_path):
|
12 |
+
os.mkdir(save_path)
|
13 |
+
#-----------------------------------------
|
14 |
+
#query
|
15 |
+
query_path = download_path + '/query'
|
16 |
+
query_save_path = download_path + '/pytorch/query'
|
17 |
+
if not os.path.isdir(query_save_path):
|
18 |
+
os.mkdir(query_save_path)
|
19 |
+
|
20 |
+
for root, dirs, files in os.walk(query_path, topdown=True):
|
21 |
+
for name in files:
|
22 |
+
if not name[-3:]=='jpg':
|
23 |
+
continue
|
24 |
+
ID = name.split('_')
|
25 |
+
src_path = query_path + '/' + name
|
26 |
+
dst_path = query_save_path + '/' + ID[0]
|
27 |
+
if not os.path.isdir(dst_path):
|
28 |
+
os.mkdir(dst_path)
|
29 |
+
copyfile(src_path, dst_path + '/' + name)
|
30 |
+
|
31 |
+
#-----------------------------------------
|
32 |
+
#multi-query
|
33 |
+
query_path = download_path + '/gt_bbox'
|
34 |
+
# for dukemtmc-reid, we do not need multi-query
|
35 |
+
if os.path.isdir(query_path):
|
36 |
+
query_save_path = download_path + '/pytorch/multi-query'
|
37 |
+
if not os.path.isdir(query_save_path):
|
38 |
+
os.mkdir(query_save_path)
|
39 |
+
|
40 |
+
for root, dirs, files in os.walk(query_path, topdown=True):
|
41 |
+
for name in files:
|
42 |
+
if not name[-3:]=='jpg':
|
43 |
+
continue
|
44 |
+
ID = name.split('_')
|
45 |
+
src_path = query_path + '/' + name
|
46 |
+
dst_path = query_save_path + '/' + ID[0]
|
47 |
+
if not os.path.isdir(dst_path):
|
48 |
+
os.mkdir(dst_path)
|
49 |
+
copyfile(src_path, dst_path + '/' + name)
|
50 |
+
|
51 |
+
#-----------------------------------------
|
52 |
+
#gallery
|
53 |
+
gallery_path = download_path + '/bounding_box_test'
|
54 |
+
gallery_save_path = download_path + '/pytorch/gallery'
|
55 |
+
if not os.path.isdir(gallery_save_path):
|
56 |
+
os.mkdir(gallery_save_path)
|
57 |
+
|
58 |
+
for root, dirs, files in os.walk(gallery_path, topdown=True):
|
59 |
+
for name in files:
|
60 |
+
if not name[-3:]=='jpg':
|
61 |
+
continue
|
62 |
+
ID = name.split('_')
|
63 |
+
src_path = gallery_path + '/' + name
|
64 |
+
dst_path = gallery_save_path + '/' + ID[0]
|
65 |
+
if not os.path.isdir(dst_path):
|
66 |
+
os.mkdir(dst_path)
|
67 |
+
copyfile(src_path, dst_path + '/' + name)
|
68 |
+
|
69 |
+
#---------------------------------------
|
70 |
+
#train_all
|
71 |
+
train_path = download_path + '/bounding_box_train'
|
72 |
+
train_save_path = download_path + '/pytorch/train_all'
|
73 |
+
if not os.path.isdir(train_save_path):
|
74 |
+
os.mkdir(train_save_path)
|
75 |
+
|
76 |
+
for root, dirs, files in os.walk(train_path, topdown=True):
|
77 |
+
for name in files:
|
78 |
+
if not name[-3:]=='jpg':
|
79 |
+
continue
|
80 |
+
ID = name.split('_')
|
81 |
+
src_path = train_path + '/' + name
|
82 |
+
dst_path = train_save_path + '/' + ID[0]
|
83 |
+
if not os.path.isdir(dst_path):
|
84 |
+
os.mkdir(dst_path)
|
85 |
+
copyfile(src_path, dst_path + '/' + name)
|
86 |
+
|
87 |
+
|
88 |
+
#---------------------------------------
|
89 |
+
#train_val
|
90 |
+
train_path = download_path + '/bounding_box_train'
|
91 |
+
train_save_path = download_path + '/pytorch/train'
|
92 |
+
val_save_path = download_path + '/pytorch/test'
|
93 |
+
if not os.path.isdir(train_save_path):
|
94 |
+
os.mkdir(train_save_path)
|
95 |
+
os.mkdir(val_save_path)
|
96 |
+
|
97 |
+
for root, dirs, files in os.walk(train_path, topdown=True):
|
98 |
+
for name in files:
|
99 |
+
if not name[-3:]=='jpg':
|
100 |
+
continue
|
101 |
+
ID = name.split('_')
|
102 |
+
src_path = train_path + '/' + name
|
103 |
+
dst_path = train_save_path + '/' + ID[0]
|
104 |
+
if not os.path.isdir(dst_path):
|
105 |
+
os.mkdir(dst_path)
|
106 |
+
dst_path = val_save_path + '/' + ID[0] #first image is used as val image
|
107 |
+
os.mkdir(dst_path)
|
108 |
+
copyfile(src_path, dst_path + '/' + name)
|
deep_sort/deep_sort/deep/test.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.backends.cudnn as cudnn
|
3 |
+
import torchvision
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import os
|
7 |
+
|
8 |
+
from model import Net
|
9 |
+
|
10 |
+
parser = argparse.ArgumentParser(description="Train on market1501")
|
11 |
+
parser.add_argument("--data-dir",default='data',type=str)
|
12 |
+
parser.add_argument("--no-cuda",action="store_true")
|
13 |
+
parser.add_argument("--gpu-id",default=0,type=int)
|
14 |
+
args = parser.parse_args()
|
15 |
+
|
16 |
+
# device
|
17 |
+
device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu"
|
18 |
+
if torch.cuda.is_available() and not args.no_cuda:
|
19 |
+
cudnn.benchmark = True
|
20 |
+
|
21 |
+
# data loader
|
22 |
+
root = args.data_dir
|
23 |
+
query_dir = os.path.join(root,"query")
|
24 |
+
gallery_dir = os.path.join(root,"gallery")
|
25 |
+
transform = torchvision.transforms.Compose([
|
26 |
+
torchvision.transforms.Resize((128,64)),
|
27 |
+
torchvision.transforms.ToTensor(),
|
28 |
+
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
29 |
+
])
|
30 |
+
queryloader = torch.utils.data.DataLoader(
|
31 |
+
torchvision.datasets.ImageFolder(query_dir, transform=transform),
|
32 |
+
batch_size=64, shuffle=False
|
33 |
+
)
|
34 |
+
galleryloader = torch.utils.data.DataLoader(
|
35 |
+
torchvision.datasets.ImageFolder(gallery_dir, transform=transform),
|
36 |
+
batch_size=64, shuffle=False
|
37 |
+
)
|
38 |
+
|
39 |
+
# net definition
|
40 |
+
net = Net(reid=True)
|
41 |
+
assert os.path.isfile("./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"
|
42 |
+
print('Loading from checkpoint/ckpt.t7')
|
43 |
+
checkpoint = torch.load("./checkpoint/ckpt.t7")
|
44 |
+
net_dict = checkpoint['net_dict']
|
45 |
+
net.load_state_dict(net_dict, strict=False)
|
46 |
+
net.eval()
|
47 |
+
net.to(device)
|
48 |
+
|
49 |
+
# compute features
|
50 |
+
query_features = torch.tensor([]).float()
|
51 |
+
query_labels = torch.tensor([]).long()
|
52 |
+
gallery_features = torch.tensor([]).float()
|
53 |
+
gallery_labels = torch.tensor([]).long()
|
54 |
+
|
55 |
+
with torch.no_grad():
|
56 |
+
for idx,(inputs,labels) in enumerate(queryloader):
|
57 |
+
inputs = inputs.to(device)
|
58 |
+
features = net(inputs).cpu()
|
59 |
+
query_features = torch.cat((query_features, features), dim=0)
|
60 |
+
query_labels = torch.cat((query_labels, labels))
|
61 |
+
|
62 |
+
for idx,(inputs,labels) in enumerate(galleryloader):
|
63 |
+
inputs = inputs.to(device)
|
64 |
+
features = net(inputs).cpu()
|
65 |
+
gallery_features = torch.cat((gallery_features, features), dim=0)
|
66 |
+
gallery_labels = torch.cat((gallery_labels, labels))
|
67 |
+
|
68 |
+
gallery_labels -= 2
|
69 |
+
|
70 |
+
# save features
|
71 |
+
features = {
|
72 |
+
"qf": query_features,
|
73 |
+
"ql": query_labels,
|
74 |
+
"gf": gallery_features,
|
75 |
+
"gl": gallery_labels
|
76 |
+
}
|
77 |
+
torch.save(features,"features.pth")
|
deep_sort/deep_sort/deep/train.jpg
ADDED
![]() |
deep_sort/deep_sort/deep/train.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import torch
|
8 |
+
import torch.backends.cudnn as cudnn
|
9 |
+
import torchvision
|
10 |
+
|
11 |
+
from model import Net
|
12 |
+
|
13 |
+
parser = argparse.ArgumentParser(description="Train on market1501")
|
14 |
+
parser.add_argument("--data-dir",default='data',type=str)
|
15 |
+
parser.add_argument("--no-cuda",action="store_true")
|
16 |
+
parser.add_argument("--gpu-id",default=0,type=int)
|
17 |
+
parser.add_argument("--lr",default=0.1, type=float)
|
18 |
+
parser.add_argument("--interval",'-i',default=20,type=int)
|
19 |
+
parser.add_argument('--resume', '-r',action='store_true')
|
20 |
+
args = parser.parse_args()
|
21 |
+
|
22 |
+
# device
|
23 |
+
device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu"
|
24 |
+
if torch.cuda.is_available() and not args.no_cuda:
|
25 |
+
cudnn.benchmark = True
|
26 |
+
|
27 |
+
# data loading
|
28 |
+
root = args.data_dir
|
29 |
+
train_dir = os.path.join(root,"train")
|
30 |
+
test_dir = os.path.join(root,"test")
|
31 |
+
|
32 |
+
transform_train = torchvision.transforms.Compose([
|
33 |
+
torchvision.transforms.RandomCrop((128,64),padding=4),
|
34 |
+
torchvision.transforms.RandomHorizontalFlip(),
|
35 |
+
torchvision.transforms.ToTensor(),
|
36 |
+
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
37 |
+
])
|
38 |
+
transform_test = torchvision.transforms.Compose([
|
39 |
+
torchvision.transforms.Resize((128,64)),
|
40 |
+
torchvision.transforms.ToTensor(),
|
41 |
+
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
42 |
+
])
|
43 |
+
trainloader = torch.utils.data.DataLoader(
|
44 |
+
torchvision.datasets.ImageFolder(train_dir, transform=transform_train),
|
45 |
+
batch_size=64,shuffle=True
|
46 |
+
)
|
47 |
+
testloader = torch.utils.data.DataLoader(
|
48 |
+
torchvision.datasets.ImageFolder(test_dir, transform=transform_test),
|
49 |
+
batch_size=64,shuffle=True
|
50 |
+
)
|
51 |
+
num_classes = max(len(trainloader.dataset.classes), len(testloader.dataset.classes))
|
52 |
+
print("num_classes = %s" %num_classes)
|
53 |
+
|
54 |
+
# net definition
|
55 |
+
start_epoch = 0
|
56 |
+
net = Net(num_classes=num_classes)
|
57 |
+
if args.resume:
|
58 |
+
assert os.path.isfile("./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"
|
59 |
+
print('Loading from checkpoint/ckpt.t7')
|
60 |
+
checkpoint = torch.load("./checkpoint/ckpt.t7")
|
61 |
+
# import ipdb; ipdb.set_trace()
|
62 |
+
net_dict = checkpoint['net_dict']
|
63 |
+
net.load_state_dict(net_dict)
|
64 |
+
best_acc = checkpoint['acc']
|
65 |
+
start_epoch = checkpoint['epoch']
|
66 |
+
net.to(device)
|
67 |
+
|
68 |
+
# loss and optimizer
|
69 |
+
criterion = torch.nn.CrossEntropyLoss()
|
70 |
+
optimizer = torch.optim.SGD(net.parameters(), args.lr, momentum=0.9, weight_decay=5e-4)
|
71 |
+
best_acc = 0.
|
72 |
+
|
73 |
+
# train function for each epoch
|
74 |
+
def train(epoch):
|
75 |
+
print("\nEpoch : %d"%(epoch+1))
|
76 |
+
net.train()
|
77 |
+
training_loss = 0.
|
78 |
+
train_loss = 0.
|
79 |
+
correct = 0
|
80 |
+
total = 0
|
81 |
+
interval = args.interval
|
82 |
+
start = time.time()
|
83 |
+
for idx, (inputs, labels) in enumerate(trainloader):
|
84 |
+
# forward
|
85 |
+
inputs,labels = inputs.to(device),labels.to(device)
|
86 |
+
outputs = net(inputs)
|
87 |
+
loss = criterion(outputs, labels)
|
88 |
+
|
89 |
+
# backward
|
90 |
+
optimizer.zero_grad()
|
91 |
+
loss.backward()
|
92 |
+
optimizer.step()
|
93 |
+
|
94 |
+
# accumurating
|
95 |
+
training_loss += loss.item()
|
96 |
+
train_loss += loss.item()
|
97 |
+
correct += outputs.max(dim=1)[1].eq(labels).sum().item()
|
98 |
+
total += labels.size(0)
|
99 |
+
|
100 |
+
# print
|
101 |
+
if (idx+1)%interval == 0:
|
102 |
+
end = time.time()
|
103 |
+
print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
|
104 |
+
100.*(idx+1)/len(trainloader), end-start, training_loss/interval, correct, total, 100.*correct/total
|
105 |
+
))
|
106 |
+
training_loss = 0.
|
107 |
+
start = time.time()
|
108 |
+
|
109 |
+
return train_loss/len(trainloader), 1.- correct/total
|
110 |
+
|
111 |
+
def test(epoch):
|
112 |
+
global best_acc
|
113 |
+
net.eval()
|
114 |
+
test_loss = 0.
|
115 |
+
correct = 0
|
116 |
+
total = 0
|
117 |
+
start = time.time()
|
118 |
+
with torch.no_grad():
|
119 |
+
for idx, (inputs, labels) in enumerate(testloader):
|
120 |
+
inputs, labels = inputs.to(device), labels.to(device)
|
121 |
+
outputs = net(inputs)
|
122 |
+
loss = criterion(outputs, labels)
|
123 |
+
|
124 |
+
test_loss += loss.item()
|
125 |
+
correct += outputs.max(dim=1)[1].eq(labels).sum().item()
|
126 |
+
total += labels.size(0)
|
127 |
+
|
128 |
+
print("Testing ...")
|
129 |
+
end = time.time()
|
130 |
+
print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
|
131 |
+
100.*(idx+1)/len(testloader), end-start, test_loss/len(testloader), correct, total, 100.*correct/total
|
132 |
+
))
|
133 |
+
|
134 |
+
# saving checkpoint
|
135 |
+
acc = 100.*correct/total
|
136 |
+
if acc > best_acc:
|
137 |
+
best_acc = acc
|
138 |
+
print("Saving parameters to checkpoint/ckpt.t7")
|
139 |
+
checkpoint = {
|
140 |
+
'net_dict':net.state_dict(),
|
141 |
+
'acc':acc,
|
142 |
+
'epoch':epoch,
|
143 |
+
}
|
144 |
+
if not os.path.isdir('checkpoint'):
|
145 |
+
os.mkdir('checkpoint')
|
146 |
+
torch.save(checkpoint, './checkpoint/ckpt.t7')
|
147 |
+
|
148 |
+
return test_loss/len(testloader), 1.- correct/total
|
149 |
+
|
150 |
+
# plot figure
|
151 |
+
x_epoch = []
|
152 |
+
record = {'train_loss':[], 'train_err':[], 'test_loss':[], 'test_err':[]}
|
153 |
+
fig = plt.figure()
|
154 |
+
ax0 = fig.add_subplot(121, title="loss")
|
155 |
+
ax1 = fig.add_subplot(122, title="top1err")
|
156 |
+
def draw_curve(epoch, train_loss, train_err, test_loss, test_err):
|
157 |
+
global record
|
158 |
+
record['train_loss'].append(train_loss)
|
159 |
+
record['train_err'].append(train_err)
|
160 |
+
record['test_loss'].append(test_loss)
|
161 |
+
record['test_err'].append(test_err)
|
162 |
+
|
163 |
+
x_epoch.append(epoch)
|
164 |
+
ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train')
|
165 |
+
ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val')
|
166 |
+
ax1.plot(x_epoch, record['train_err'], 'bo-', label='train')
|
167 |
+
ax1.plot(x_epoch, record['test_err'], 'ro-', label='val')
|
168 |
+
if epoch == 0:
|
169 |
+
ax0.legend()
|
170 |
+
ax1.legend()
|
171 |
+
fig.savefig("train.jpg")
|
172 |
+
|
173 |
+
# lr decay
|
174 |
+
def lr_decay():
|
175 |
+
global optimizer
|
176 |
+
for params in optimizer.param_groups:
|
177 |
+
params['lr'] *= 0.1
|
178 |
+
lr = params['lr']
|
179 |
+
print("Learning rate adjusted to {}".format(lr))
|
180 |
+
|
181 |
+
def main():
|
182 |
+
total_epoches = 40
|
183 |
+
for epoch in range(start_epoch, start_epoch+total_epoches):
|
184 |
+
train_loss, train_err = train(epoch)
|
185 |
+
test_loss, test_err = test(epoch)
|
186 |
+
draw_curve(epoch, train_loss, train_err, test_loss, test_err)
|
187 |
+
if (epoch+1)%(total_epoches//2)==0:
|
188 |
+
lr_decay()
|
189 |
+
|
190 |
+
|
191 |
+
if __name__ == '__main__':
|
192 |
+
main()
|
deep_sort/deep_sort/deep_sort.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from .deep.feature_extractor import Extractor
|
5 |
+
from .sort.nn_matching import NearestNeighborDistanceMetric
|
6 |
+
from .sort.preprocessing import non_max_suppression
|
7 |
+
from .sort.detection import Detection
|
8 |
+
from .sort.tracker import Tracker
|
9 |
+
|
10 |
+
|
11 |
+
__all__ = ['DeepSort'] # __all__ 提供了暴露接口用的”白名单“
|
12 |
+
|
13 |
+
|
14 |
+
class DeepSort(object):
|
15 |
+
def __init__(self, model_path, max_dist=0.2, min_confidence=0.3, nms_max_overlap=1.0, max_iou_distance=0.7, max_age=70, n_init=3, nn_budget=100, use_cuda=True):
|
16 |
+
self.min_confidence = min_confidence # 检测结果置信度阈值
|
17 |
+
self.nms_max_overlap = nms_max_overlap # 非极大抑制阈值,设置为1代表不进行抑制
|
18 |
+
|
19 |
+
self.extractor = Extractor(model_path, use_cuda=use_cuda) # 用于提取一个batch图片对应的特征
|
20 |
+
|
21 |
+
max_cosine_distance = max_dist # 最大余弦距离,用于级联匹配,如果大于该阈值,则忽略
|
22 |
+
nn_budget = 100 # 每个类别gallery最多的外观描述子的个数,如果超过,删除旧的
|
23 |
+
# NearestNeighborDistanceMetric 最近邻距离度量
|
24 |
+
# 对于每个目标,返回到目前为止已观察到的任何样本的最近距离(欧式或余弦)。
|
25 |
+
# 由距离度量方法构造一个 Tracker。
|
26 |
+
# 第一个参数可选'cosine' or 'euclidean'
|
27 |
+
self.metric = NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
|
28 |
+
self.tracker = Tracker(self.metric, max_iou_distance=max_iou_distance, max_age=max_age, n_init=n_init)
|
29 |
+
|
30 |
+
def update(self, bbox_xywh, confidences, ori_img):
|
31 |
+
self.height, self.width = ori_img.shape[:2]
|
32 |
+
# generate detections
|
33 |
+
# 从原图中抠取bbox对应图片并计算得到相应的特征
|
34 |
+
features = self._get_features(bbox_xywh, ori_img)
|
35 |
+
bbox_tlwh = self._xywh_to_tlwh(bbox_xywh)
|
36 |
+
# 筛选掉小于min_confidence的目标,并构造一个Detection对象构成的列表
|
37 |
+
detections = [Detection(bbox_tlwh[i], conf, features[i]) for i,conf in enumerate(confidences) if conf>self.min_confidence]
|
38 |
+
|
39 |
+
# run on non-maximum supression
|
40 |
+
boxes = np.array([d.tlwh for d in detections])
|
41 |
+
scores = np.array([d.confidence for d in detections])
|
42 |
+
indices = non_max_suppression(boxes, self.nms_max_overlap, scores)
|
43 |
+
detections = [detections[i] for i in indices]
|
44 |
+
|
45 |
+
# update tracker
|
46 |
+
self.tracker.predict() # 将跟踪状态分布向前传播一步
|
47 |
+
self.tracker.update(detections) # 执行测量更新和跟踪管理
|
48 |
+
|
49 |
+
# output bbox identities
|
50 |
+
outputs = []
|
51 |
+
for track in self.tracker.tracks:
|
52 |
+
if not track.is_confirmed() or track.time_since_update > 1:
|
53 |
+
continue
|
54 |
+
box = track.to_tlwh()
|
55 |
+
x1,y1,x2,y2 = self._tlwh_to_xyxy(box)
|
56 |
+
track_id = track.track_id
|
57 |
+
outputs.append(np.array([x1,y1,x2,y2,track_id], dtype=np.int16))
|
58 |
+
if len(outputs) > 0:
|
59 |
+
outputs = np.stack(outputs,axis=0)
|
60 |
+
return outputs
|
61 |
+
|
62 |
+
|
63 |
+
"""
|
64 |
+
TODO:
|
65 |
+
Convert bbox from xc_yc_w_h to xtl_ytl_w_h
|
66 |
+
Thanks [email protected] for reporting this bug!
|
67 |
+
"""
|
68 |
+
#将bbox的[x,y,w,h] 转换成[t,l,w,h]
|
69 |
+
@staticmethod
|
70 |
+
def _xywh_to_tlwh(bbox_xywh):
|
71 |
+
if isinstance(bbox_xywh, np.ndarray):
|
72 |
+
bbox_tlwh = bbox_xywh.copy()
|
73 |
+
elif isinstance(bbox_xywh, torch.Tensor):
|
74 |
+
bbox_tlwh = bbox_xywh.clone()
|
75 |
+
bbox_tlwh[:,0] = bbox_xywh[:,0] - bbox_xywh[:,2]/2.
|
76 |
+
bbox_tlwh[:,1] = bbox_xywh[:,1] - bbox_xywh[:,3]/2.
|
77 |
+
return bbox_tlwh
|
78 |
+
|
79 |
+
#将bbox的[x,y,w,h] 转换成[x1,y1,x2,y2]
|
80 |
+
#某些数据集例如 pascal_voc 的标注方式是采用[x,y,w,h]
|
81 |
+
"""Convert [x y w h] box format to [x1 y1 x2 y2] format."""
|
82 |
+
def _xywh_to_xyxy(self, bbox_xywh):
|
83 |
+
x,y,w,h = bbox_xywh
|
84 |
+
x1 = max(int(x-w/2),0)
|
85 |
+
x2 = min(int(x+w/2),self.width-1)
|
86 |
+
y1 = max(int(y-h/2),0)
|
87 |
+
y2 = min(int(y+h/2),self.height-1)
|
88 |
+
return x1,y1,x2,y2
|
89 |
+
|
90 |
+
def _tlwh_to_xyxy(self, bbox_tlwh):
|
91 |
+
"""
|
92 |
+
TODO:
|
93 |
+
Convert bbox from xtl_ytl_w_h to xc_yc_w_h
|
94 |
+
Thanks [email protected] for reporting this bug!
|
95 |
+
"""
|
96 |
+
x,y,w,h = bbox_tlwh
|
97 |
+
x1 = max(int(x),0)
|
98 |
+
x2 = min(int(x+w),self.width-1)
|
99 |
+
y1 = max(int(y),0)
|
100 |
+
y2 = min(int(y+h),self.height-1)
|
101 |
+
return x1,y1,x2,y2
|
102 |
+
|
103 |
+
def _xyxy_to_tlwh(self, bbox_xyxy):
|
104 |
+
x1,y1,x2,y2 = bbox_xyxy
|
105 |
+
|
106 |
+
t = x1
|
107 |
+
l = y1
|
108 |
+
w = int(x2-x1)
|
109 |
+
h = int(y2-y1)
|
110 |
+
return t,l,w,h
|
111 |
+
|
112 |
+
# 获取抠图部分的特征
|
113 |
+
def _get_features(self, bbox_xywh, ori_img):
|
114 |
+
im_crops = []
|
115 |
+
for box in bbox_xywh:
|
116 |
+
x1,y1,x2,y2 = self._xywh_to_xyxy(box)
|
117 |
+
im = ori_img[y1:y2,x1:x2] # 抠图部分
|
118 |
+
im_crops.append(im)
|
119 |
+
if im_crops:
|
120 |
+
features = self.extractor(im_crops) # 对抠图部分提取特征
|
121 |
+
else:
|
122 |
+
features = np.array([])
|
123 |
+
return features
|
124 |
+
|
125 |
+
|
deep_sort/deep_sort/sort/__init__.py
ADDED
File without changes
|
deep_sort/deep_sort/sort/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (168 Bytes). View file
|
|
deep_sort/deep_sort/sort/__pycache__/detection.cpython-310.pyc
ADDED
Binary file (1.91 kB). View file
|
|
deep_sort/deep_sort/sort/__pycache__/iou_matching.cpython-310.pyc
ADDED
Binary file (2.95 kB). View file
|
|
deep_sort/deep_sort/sort/__pycache__/kalman_filter.cpython-310.pyc
ADDED
Binary file (7.95 kB). View file
|
|
deep_sort/deep_sort/sort/__pycache__/linear_assignment.cpython-310.pyc
ADDED
Binary file (8.19 kB). View file
|
|
deep_sort/deep_sort/sort/__pycache__/nn_matching.cpython-310.pyc
ADDED
Binary file (7.45 kB). View file
|
|
deep_sort/deep_sort/sort/__pycache__/preprocessing.cpython-310.pyc
ADDED
Binary file (1.92 kB). View file
|
|
deep_sort/deep_sort/sort/__pycache__/track.cpython-310.pyc
ADDED
Binary file (6.89 kB). View file
|
|
deep_sort/deep_sort/sort/__pycache__/tracker.cpython-310.pyc
ADDED
Binary file (5.71 kB). View file
|
|
deep_sort/deep_sort/sort/detection.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# vim: expandtab:ts=4:sw=4
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
class Detection(object):
|
6 |
+
"""
|
7 |
+
This class represents a bounding box detection in a single image.
|
8 |
+
|
9 |
+
Parameters
|
10 |
+
----------
|
11 |
+
tlwh : array_like
|
12 |
+
Bounding box in format `(top left x, top left y, width, height)`.
|
13 |
+
confidence : float
|
14 |
+
Detector confidence score.
|
15 |
+
feature : array_like
|
16 |
+
A feature vector that describes the object contained in this image.
|
17 |
+
|
18 |
+
Attributes
|
19 |
+
----------
|
20 |
+
tlwh : ndarray
|
21 |
+
Bounding box in format `(top left x, top left y, width, height)`.
|
22 |
+
confidence : ndarray
|
23 |
+
Detector confidence score.
|
24 |
+
feature : ndarray | NoneType
|
25 |
+
A feature vector that describes the object contained in this image.
|
26 |
+
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self, tlwh, confidence, feature):
|
30 |
+
self.tlwh = np.asarray(tlwh, dtype=np.float32)
|
31 |
+
self.confidence = float(confidence)
|
32 |
+
self.feature = np.asarray(feature, dtype=np.float32)
|
33 |
+
|
34 |
+
def to_tlbr(self):
|
35 |
+
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
|
36 |
+
`(top left, bottom right)`.
|
37 |
+
"""
|
38 |
+
ret = self.tlwh.copy()
|
39 |
+
ret[2:] += ret[:2]
|
40 |
+
return ret
|
41 |
+
|
42 |
+
def to_xyah(self):
|
43 |
+
"""Convert bounding box to format `(center x, center y, aspect ratio,
|
44 |
+
height)`, where the aspect ratio is `width / height`.
|
45 |
+
"""
|
46 |
+
ret = self.tlwh.copy()
|
47 |
+
ret[:2] += ret[2:] / 2
|
48 |
+
ret[2] /= ret[3]
|
49 |
+
return ret
|
deep_sort/deep_sort/sort/iou_matching.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# vim: expandtab:ts=4:sw=4
|
2 |
+
from __future__ import absolute_import
|
3 |
+
import numpy as np
|
4 |
+
from . import linear_assignment
|
5 |
+
|
6 |
+
#计算两个框的IOU
|
7 |
+
def iou(bbox, candidates):
|
8 |
+
"""Computer intersection over union.
|
9 |
+
|
10 |
+
Parameters
|
11 |
+
----------
|
12 |
+
bbox : ndarray
|
13 |
+
A bounding box in format `(top left x, top left y, width, height)`.
|
14 |
+
candidates : ndarray
|
15 |
+
A matrix of candidate bounding boxes (one per row) in the same format
|
16 |
+
as `bbox`.
|
17 |
+
|
18 |
+
Returns
|
19 |
+
-------
|
20 |
+
ndarray
|
21 |
+
The intersection over union in [0, 1] between the `bbox` and each
|
22 |
+
candidate. A higher score means a larger fraction of the `bbox` is
|
23 |
+
occluded by the candidate.
|
24 |
+
|
25 |
+
"""
|
26 |
+
bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:]
|
27 |
+
candidates_tl = candidates[:, :2]
|
28 |
+
candidates_br = candidates[:, :2] + candidates[:, 2:]
|
29 |
+
|
30 |
+
# np.c_ Translates slice objects to concatenation along the second axis.
|
31 |
+
tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
|
32 |
+
np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
|
33 |
+
br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
|
34 |
+
np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
|
35 |
+
wh = np.maximum(0., br - tl)
|
36 |
+
|
37 |
+
area_intersection = wh.prod(axis=1)
|
38 |
+
area_bbox = bbox[2:].prod()
|
39 |
+
area_candidates = candidates[:, 2:].prod(axis=1)
|
40 |
+
return area_intersection / (area_bbox + area_candidates - area_intersection)
|
41 |
+
|
42 |
+
# 计算tracks和detections之间的IOU距离成本矩阵
|
43 |
+
def iou_cost(tracks, detections, track_indices=None,
|
44 |
+
detection_indices=None):
|
45 |
+
"""An intersection over union distance metric.
|
46 |
+
|
47 |
+
用于计算tracks和detections之间的iou距离矩阵
|
48 |
+
|
49 |
+
Parameters
|
50 |
+
----------
|
51 |
+
tracks : List[deep_sort.track.Track]
|
52 |
+
A list of tracks.
|
53 |
+
detections : List[deep_sort.detection.Detection]
|
54 |
+
A list of detections.
|
55 |
+
track_indices : Optional[List[int]]
|
56 |
+
A list of indices to tracks that should be matched. Defaults to
|
57 |
+
all `tracks`.
|
58 |
+
detection_indices : Optional[List[int]]
|
59 |
+
A list of indices to detections that should be matched. Defaults
|
60 |
+
to all `detections`.
|
61 |
+
|
62 |
+
Returns
|
63 |
+
-------
|
64 |
+
ndarray
|
65 |
+
Returns a cost matrix of shape
|
66 |
+
len(track_indices), len(detection_indices) where entry (i, j) is
|
67 |
+
`1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
|
68 |
+
|
69 |
+
"""
|
70 |
+
if track_indices is None:
|
71 |
+
track_indices = np.arange(len(tracks))
|
72 |
+
if detection_indices is None:
|
73 |
+
detection_indices = np.arange(len(detections))
|
74 |
+
|
75 |
+
cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
|
76 |
+
for row, track_idx in enumerate(track_indices):
|
77 |
+
if tracks[track_idx].time_since_update > 1:
|
78 |
+
cost_matrix[row, :] = linear_assignment.INFTY_COST
|
79 |
+
continue
|
80 |
+
|
81 |
+
bbox = tracks[track_idx].to_tlwh()
|
82 |
+
candidates = np.asarray([detections[i].tlwh for i in detection_indices])
|
83 |
+
cost_matrix[row, :] = 1. - iou(bbox, candidates)
|
84 |
+
return cost_matrix
|
deep_sort/deep_sort/sort/kalman_filter.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# vim: expandtab:ts=4:sw=4
|
2 |
+
import numpy as np
|
3 |
+
import scipy.linalg
|
4 |
+
|
5 |
+
|
6 |
+
"""
|
7 |
+
Table for the 0.95 quantile of the chi-square distribution with N degrees of
|
8 |
+
freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
|
9 |
+
function and used as Mahalanobis gating threshold.
|
10 |
+
"""
|
11 |
+
chi2inv95 = {
|
12 |
+
1: 3.8415,
|
13 |
+
2: 5.9915,
|
14 |
+
3: 7.8147,
|
15 |
+
4: 9.4877,
|
16 |
+
5: 11.070,
|
17 |
+
6: 12.592,
|
18 |
+
7: 14.067,
|
19 |
+
8: 15.507,
|
20 |
+
9: 16.919}
|
21 |
+
|
22 |
+
'''
|
23 |
+
卡尔曼滤波分为两个阶段:
|
24 |
+
(1) 预测track在下一时刻的位置,
|
25 |
+
(2) 基于detection来更新预测的位置。
|
26 |
+
'''
|
27 |
+
class KalmanFilter(object):
|
28 |
+
"""
|
29 |
+
A simple Kalman filter for tracking bounding boxes in image space.
|
30 |
+
|
31 |
+
The 8-dimensional state space
|
32 |
+
|
33 |
+
x, y, a, h, vx, vy, va, vh
|
34 |
+
|
35 |
+
contains the bounding box center position (x, y), aspect ratio a, height h,
|
36 |
+
and their respective velocities.
|
37 |
+
|
38 |
+
Object motion follows a constant velocity model. The bounding box location
|
39 |
+
(x, y, a, h) is taken as direct observation of the state space (linear
|
40 |
+
observation model).
|
41 |
+
|
42 |
+
对于每个轨迹,由一个 KalmanFilter 预测状态分布。每个轨迹记录自己的均值和方差作为滤波器输入。
|
43 |
+
|
44 |
+
8维状态空间[x, y, a, h, vx, vy, va, vh]包含边界框中心位置(x, y),纵横比a,高度h和它们各自的速度。
|
45 |
+
物体运动遵循恒速模型。 边界框位置(x, y, a, h)被视为状态空间的直接观察(线性观察模型)
|
46 |
+
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(self):
|
50 |
+
ndim, dt = 4, 1.
|
51 |
+
|
52 |
+
# Create Kalman filter model matrices.
|
53 |
+
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
|
54 |
+
for i in range(ndim):
|
55 |
+
self._motion_mat[i, ndim + i] = dt
|
56 |
+
self._update_mat = np.eye(ndim, 2 * ndim)
|
57 |
+
|
58 |
+
# Motion and observation uncertainty are chosen relative to the current
|
59 |
+
# state estimate. These weights control the amount of uncertainty in
|
60 |
+
# the model. This is a bit hacky.
|
61 |
+
# 依据当前状态估计(高度)选择运动和观测不确定性。这些权重控制模型中的不确定性。
|
62 |
+
self._std_weight_position = 1. / 20
|
63 |
+
self._std_weight_velocity = 1. / 160
|
64 |
+
|
65 |
+
def initiate(self, measurement):
|
66 |
+
"""Create track from unassociated measurement.
|
67 |
+
|
68 |
+
Parameters
|
69 |
+
----------
|
70 |
+
measurement : ndarray
|
71 |
+
Bounding box coordinates (x, y, a, h) with center position (x, y),
|
72 |
+
aspect ratio a, and height h.
|
73 |
+
|
74 |
+
Returns
|
75 |
+
-------
|
76 |
+
(ndarray, ndarray)
|
77 |
+
Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
78 |
+
dimensional) of the new track. Unobserved velocities are initialized
|
79 |
+
to 0 mean.
|
80 |
+
|
81 |
+
"""
|
82 |
+
|
83 |
+
|
84 |
+
mean_pos = measurement
|
85 |
+
mean_vel = np.zeros_like(mean_pos)
|
86 |
+
# Translates slice objects to concatenation along the first axis
|
87 |
+
mean = np.r_[mean_pos, mean_vel]
|
88 |
+
|
89 |
+
# 由测量初始化均值向量(8维)和协方差矩阵(8x8维)
|
90 |
+
std = [
|
91 |
+
2 * self._std_weight_position * measurement[3],
|
92 |
+
2 * self._std_weight_position * measurement[3],
|
93 |
+
1e-2,
|
94 |
+
2 * self._std_weight_position * measurement[3],
|
95 |
+
10 * self._std_weight_velocity * measurement[3],
|
96 |
+
10 * self._std_weight_velocity * measurement[3],
|
97 |
+
1e-5,
|
98 |
+
10 * self._std_weight_velocity * measurement[3]]
|
99 |
+
covariance = np.diag(np.square(std))
|
100 |
+
return mean, covariance
|
101 |
+
|
102 |
+
def predict(self, mean, covariance):
|
103 |
+
"""Run Kalman filter prediction step.
|
104 |
+
|
105 |
+
Parameters
|
106 |
+
----------
|
107 |
+
mean : ndarray
|
108 |
+
The 8 dimensional mean vector of the object state at the previous
|
109 |
+
time step.
|
110 |
+
covariance : ndarray
|
111 |
+
The 8x8 dimensional covariance matrix of the object state at the
|
112 |
+
previous time step.
|
113 |
+
|
114 |
+
Returns
|
115 |
+
-------
|
116 |
+
(ndarray, ndarray)
|
117 |
+
Returns the mean vector and covariance matrix of the predicted
|
118 |
+
state. Unobserved velocities are initialized to 0 mean.
|
119 |
+
|
120 |
+
"""
|
121 |
+
#卡尔曼滤波器由目标上一时刻的均值和协方差进行预测。
|
122 |
+
std_pos = [
|
123 |
+
self._std_weight_position * mean[3],
|
124 |
+
self._std_weight_position * mean[3],
|
125 |
+
1e-2,
|
126 |
+
self._std_weight_position * mean[3]]
|
127 |
+
std_vel = [
|
128 |
+
self._std_weight_velocity * mean[3],
|
129 |
+
self._std_weight_velocity * mean[3],
|
130 |
+
1e-5,
|
131 |
+
self._std_weight_velocity * mean[3]]
|
132 |
+
|
133 |
+
# 初始化噪声矩阵Q;np.r_ 按列连接两个矩阵
|
134 |
+
# motion_cov是过程噪声 W_k的 协方差矩阵Qk
|
135 |
+
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
|
136 |
+
|
137 |
+
# Update time state x' = Fx (1)
|
138 |
+
# x为track在t-1时刻的均值,F称为状态转移矩阵,该公式预测t时刻的x'
|
139 |
+
# self._motion_mat为F_k是作用在 x_{k-1}上的状态变换模型
|
140 |
+
mean = np.dot(self._motion_mat, mean)
|
141 |
+
# Calculate error covariance P' = FPF^T+Q (2)
|
142 |
+
# P为track在t-1时刻的协方差,Q为系统的噪声矩阵,代表整个系统的可靠程度,一般初始化为很小的值,
|
143 |
+
# 该公式预测t时刻的P'
|
144 |
+
# covariance为P_{k|k} ,后验估计误差协方差矩阵,度量估计值的精确程度
|
145 |
+
covariance = np.linalg.multi_dot((
|
146 |
+
self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
|
147 |
+
|
148 |
+
return mean, covariance
|
149 |
+
|
150 |
+
def project(self, mean, covariance):
|
151 |
+
"""Project state distribution to measurement space.
|
152 |
+
投影状态分布到测量空间
|
153 |
+
|
154 |
+
Parameters
|
155 |
+
----------
|
156 |
+
mean : ndarray
|
157 |
+
The state's mean vector (8 dimensional array).
|
158 |
+
covariance : ndarray
|
159 |
+
The state's covariance matrix (8x8 dimensional).
|
160 |
+
|
161 |
+
mean:ndarray,状态的平均向量(8维数组)。
|
162 |
+
covariance:ndarray,状态的协方差矩阵(8x8维)。
|
163 |
+
|
164 |
+
Returns
|
165 |
+
-------
|
166 |
+
(ndarray, ndarray)
|
167 |
+
Returns the projected mean and covariance matrix of the given state
|
168 |
+
estimate.
|
169 |
+
|
170 |
+
返回(ndarray,ndarray),返回给定状态估计的投影平均值和协方差矩阵
|
171 |
+
|
172 |
+
"""
|
173 |
+
# 在公式4中,R为检测器的噪声矩阵,它是一个4x4的对角矩阵,
|
174 |
+
# 对角线上的值分别为中心点两个坐标以及宽高的噪声,
|
175 |
+
# 以任意值初始化,一般设置宽高的噪声大于中心点的噪声,
|
176 |
+
# 该公式先将协方差矩阵P'映射到检测空间,然后再加上噪声矩阵R;
|
177 |
+
std = [
|
178 |
+
self._std_weight_position * mean[3],
|
179 |
+
self._std_weight_position * mean[3],
|
180 |
+
1e-1,
|
181 |
+
self._std_weight_position * mean[3]]
|
182 |
+
|
183 |
+
# R为测量过程中噪声的协方差;初始化噪声矩阵R
|
184 |
+
innovation_cov = np.diag(np.square(std))
|
185 |
+
|
186 |
+
# 将均值向量映射到检测空间,即 Hx'
|
187 |
+
mean = np.dot(self._update_mat, mean)
|
188 |
+
# 将协方差矩阵映射到检测空间,即 HP'H^T
|
189 |
+
covariance = np.linalg.multi_dot((
|
190 |
+
self._update_mat, covariance, self._update_mat.T))
|
191 |
+
return mean, covariance + innovation_cov # 公式(4)
|
192 |
+
|
193 |
+
def update(self, mean, covariance, measurement):
|
194 |
+
"""Run Kalman filter correction step.
|
195 |
+
通过估计值和观测值估计最新结果
|
196 |
+
|
197 |
+
Parameters
|
198 |
+
----------
|
199 |
+
mean : ndarray
|
200 |
+
The predicted state's mean vector (8 dimensional).
|
201 |
+
covariance : ndarray
|
202 |
+
The state's covariance matrix (8x8 dimensional).
|
203 |
+
measurement : ndarray
|
204 |
+
The 4 dimensional measurement vector (x, y, a, h), where (x, y)
|
205 |
+
is the center position, a the aspect ratio, and h the height of the
|
206 |
+
bounding box.
|
207 |
+
|
208 |
+
Returns
|
209 |
+
-------
|
210 |
+
(ndarray, ndarray)
|
211 |
+
Returns the measurement-corrected state distribution.
|
212 |
+
|
213 |
+
"""
|
214 |
+
# 将均值和协方差映射到检测空间,得到 Hx'和S
|
215 |
+
projected_mean, projected_cov = self.project(mean, covariance)
|
216 |
+
|
217 |
+
# 矩阵分解
|
218 |
+
chol_factor, lower = scipy.linalg.cho_factor(
|
219 |
+
projected_cov, lower=True, check_finite=False)
|
220 |
+
# 计算卡尔曼增益K;相当于求解公式(5)
|
221 |
+
# 公式5计算卡尔曼增益K,卡尔曼增益用于估计误差的重要程度
|
222 |
+
# 求解卡尔曼滤波增益K 用到了cholesky矩阵分解加快求解;
|
223 |
+
# 公式5的右边有一个S的逆,如果S矩阵很大,S的逆求解消耗时间太大,
|
224 |
+
# 所以代码中把公式两边同时乘上S,右边的S*S的逆变成了单位矩阵,转化成AX=B形式求解。
|
225 |
+
kalman_gain = scipy.linalg.cho_solve(
|
226 |
+
(chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
|
227 |
+
check_finite=False).T
|
228 |
+
# y = z - Hx' (3)
|
229 |
+
# 在公式3中,z为detection的均值向量,不包含速度变化值,即z=[cx, cy, r, h],
|
230 |
+
# H称为测量矩阵,它将track的均值向量x'映射到检测空间,该公式计算detection和track的均值误差
|
231 |
+
innovation = measurement - projected_mean
|
232 |
+
|
233 |
+
# 更新后的均值向量 x = x' + Ky (6)
|
234 |
+
new_mean = mean + np.dot(innovation, kalman_gain.T)
|
235 |
+
# 更新后的协方差矩阵 P = (I - KH)P' (7)
|
236 |
+
new_covariance = covariance - np.linalg.multi_dot((
|
237 |
+
kalman_gain, projected_cov, kalman_gain.T))
|
238 |
+
return new_mean, new_covariance
|
239 |
+
|
240 |
+
def gating_distance(self, mean, covariance, measurements,
|
241 |
+
only_position=False):
|
242 |
+
"""Compute gating distance between state distribution and measurements.
|
243 |
+
|
244 |
+
A suitable distance threshold can be obtained from `chi2inv95`. If
|
245 |
+
`only_position` is False, the chi-square distribution has 4 degrees of
|
246 |
+
freedom, otherwise 2.
|
247 |
+
|
248 |
+
Parameters
|
249 |
+
----------
|
250 |
+
mean : ndarray
|
251 |
+
Mean vector over the state distribution (8 dimensional).
|
252 |
+
状态分布上的平均向量(8维)
|
253 |
+
covariance : ndarray
|
254 |
+
Covariance of the state distribution (8x8 dimensional).
|
255 |
+
状态分布的协方差(8x8维)
|
256 |
+
measurements : ndarray
|
257 |
+
An Nx4 dimensional matrix of N measurements, each in
|
258 |
+
format (x, y, a, h) where (x, y) is the bounding box center
|
259 |
+
position, a the aspect ratio, and h the height.
|
260 |
+
N 个测量的 N×4维矩阵,每个矩阵的格式为(x,y,a,h),其中(x,y)是边界框中心位置,宽高比和h高度。
|
261 |
+
only_position : Optional[bool]
|
262 |
+
If True, distance computation is done with respect to the bounding
|
263 |
+
box center position only.
|
264 |
+
如果为True,则只计算盒子中心位置
|
265 |
+
|
266 |
+
Returns
|
267 |
+
-------
|
268 |
+
ndarray
|
269 |
+
Returns an array of length N, where the i-th element contains the
|
270 |
+
squared Mahalanobis distance between (mean, covariance) and
|
271 |
+
`measurements[i]`.
|
272 |
+
返回一个长度为N的数组,其中第i个元素包含(mean,covariance)和measurements [i]之间的平方Mahalanobis距离
|
273 |
+
|
274 |
+
"""
|
275 |
+
mean, covariance = self.project(mean, covariance)
|
276 |
+
if only_position:
|
277 |
+
mean, covariance = mean[:2], covariance[:2, :2]
|
278 |
+
measurements = measurements[:, :2]
|
279 |
+
|
280 |
+
cholesky_factor = np.linalg.cholesky(covariance)
|
281 |
+
d = measurements - mean
|
282 |
+
z = scipy.linalg.solve_triangular(
|
283 |
+
cholesky_factor, d.T, lower=True, check_finite=False,
|
284 |
+
overwrite_b=True)
|
285 |
+
squared_maha = np.sum(z * z, axis=0)
|
286 |
+
return squared_maha
|
deep_sort/deep_sort/sort/linear_assignment.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# vim: expandtab:ts=4:sw=4
|
2 |
+
from __future__ import absolute_import
|
3 |
+
import numpy as np
|
4 |
+
# The linear sum assignment problem is also known as minimum weight matching in bipartite graphs.
|
5 |
+
from scipy.optimize import linear_sum_assignment as linear_assignment
|
6 |
+
from . import kalman_filter
|
7 |
+
|
8 |
+
|
9 |
+
INFTY_COST = 1e+5
|
10 |
+
|
11 |
+
# min_cost_matching 使用匈牙利算法解决线性分配问题。
|
12 |
+
# 传入 门控余弦距离成本 或 iou cost
|
13 |
+
def min_cost_matching(
|
14 |
+
distance_metric, max_distance, tracks, detections, track_indices=None,
|
15 |
+
detection_indices=None):
|
16 |
+
"""Solve linear assignment problem.
|
17 |
+
|
18 |
+
Parameters
|
19 |
+
----------
|
20 |
+
distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
|
21 |
+
The distance metric is given a list of tracks and detections as well as
|
22 |
+
a list of N track indices and M detection indices. The metric should
|
23 |
+
return the NxM dimensional cost matrix, where element (i, j) is the
|
24 |
+
association cost between the i-th track in the given track indices and
|
25 |
+
the j-th detection in the given detection_indices.
|
26 |
+
max_distance : float
|
27 |
+
Gating threshold. Associations with cost larger than this value are
|
28 |
+
disregarded.
|
29 |
+
tracks : List[track.Track]
|
30 |
+
A list of predicted tracks at the current time step.
|
31 |
+
detections : List[detection.Detection]
|
32 |
+
A list of detections at the current time step.
|
33 |
+
track_indices : List[int]
|
34 |
+
List of track indices that maps rows in `cost_matrix` to tracks in
|
35 |
+
`tracks` (see description above).
|
36 |
+
detection_indices : List[int]
|
37 |
+
List of detection indices that maps columns in `cost_matrix` to
|
38 |
+
detections in `detections` (see description above).
|
39 |
+
|
40 |
+
Returns
|
41 |
+
-------
|
42 |
+
(List[(int, int)], List[int], List[int])
|
43 |
+
Returns a tuple with the following three entries:
|
44 |
+
* A list of matched track and detection indices.
|
45 |
+
* A list of unmatched track indices.
|
46 |
+
* A list of unmatched detection indices.
|
47 |
+
|
48 |
+
"""
|
49 |
+
if track_indices is None:
|
50 |
+
track_indices = np.arange(len(tracks))
|
51 |
+
if detection_indices is None:
|
52 |
+
detection_indices = np.arange(len(detections))
|
53 |
+
|
54 |
+
if len(detection_indices) == 0 or len(track_indices) == 0:
|
55 |
+
return [], track_indices, detection_indices # Nothing to match.
|
56 |
+
|
57 |
+
# 计算成本矩阵
|
58 |
+
cost_matrix = distance_metric(
|
59 |
+
tracks, detections, track_indices, detection_indices)
|
60 |
+
cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
|
61 |
+
|
62 |
+
# 执行匈牙利算法,得到指派成功的索引对,行索引为tracks的索引,列索引为detections的索引
|
63 |
+
row_indices, col_indices = linear_assignment(cost_matrix)
|
64 |
+
|
65 |
+
matches, unmatched_tracks, unmatched_detections = [], [], []
|
66 |
+
# 找出未匹配的detections
|
67 |
+
for col, detection_idx in enumerate(detection_indices):
|
68 |
+
if col not in col_indices:
|
69 |
+
unmatched_detections.append(detection_idx)
|
70 |
+
# 找出未匹配的tracks
|
71 |
+
for row, track_idx in enumerate(track_indices):
|
72 |
+
if row not in row_indices:
|
73 |
+
unmatched_tracks.append(track_idx)
|
74 |
+
# 遍历匹配的(track, detection)索引对
|
75 |
+
for row, col in zip(row_indices, col_indices):
|
76 |
+
track_idx = track_indices[row]
|
77 |
+
detection_idx = detection_indices[col]
|
78 |
+
# 如果相应的cost大于阈值max_distance,也视为未匹配成功
|
79 |
+
if cost_matrix[row, col] > max_distance:
|
80 |
+
unmatched_tracks.append(track_idx)
|
81 |
+
unmatched_detections.append(detection_idx)
|
82 |
+
else:
|
83 |
+
matches.append((track_idx, detection_idx))
|
84 |
+
return matches, unmatched_tracks, unmatched_detections
|
85 |
+
|
86 |
+
|
87 |
+
def matching_cascade(
|
88 |
+
distance_metric, max_distance, cascade_depth, tracks, detections,
|
89 |
+
track_indices=None, detection_indices=None):
|
90 |
+
"""Run matching cascade.
|
91 |
+
|
92 |
+
Parameters
|
93 |
+
----------
|
94 |
+
distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
|
95 |
+
The distance metric is given a list of tracks and detections as well as
|
96 |
+
a list of N track indices and M detection indices. The metric should
|
97 |
+
return the NxM dimensional cost matrix, where element (i, j) is the
|
98 |
+
association cost between the i-th track in the given track indices and
|
99 |
+
the j-th detection in the given detection indices.
|
100 |
+
距离度量:
|
101 |
+
输入:一个轨迹和检测列表,以及一个N个轨迹索引和M个检测索引的列表。
|
102 |
+
返回:NxM维的代价矩阵,其中元素(i,j)是给定轨迹索引中第i个轨迹与
|
103 |
+
给定检测索引中第j个检测之间的关联成本。
|
104 |
+
max_distance : float
|
105 |
+
Gating threshold. Associations with cost larger than this value are
|
106 |
+
disregarded.
|
107 |
+
门控阈值。成本大于此值的关联将被忽略。
|
108 |
+
cascade_depth: int
|
109 |
+
The cascade depth, should be se to the maximum track age.
|
110 |
+
级联深度应设置为最大轨迹寿命。
|
111 |
+
tracks : List[track.Track]
|
112 |
+
A list of predicted tracks at the current time step.
|
113 |
+
当前时间步的预测轨迹列表。
|
114 |
+
detections : List[detection.Detection]
|
115 |
+
A list of detections at the current time step.
|
116 |
+
当前时间步的检测列表。
|
117 |
+
track_indices : Optional[List[int]]
|
118 |
+
List of track indices that maps rows in `cost_matrix` to tracks in
|
119 |
+
`tracks` (see description above). Defaults to all tracks.
|
120 |
+
轨迹索引列表,用于将 cost_matrix中的行映射到tracks的
|
121 |
+
轨迹(请参见上面的说明)。 默认为所有轨迹。
|
122 |
+
detection_indices : Optional[List[int]]
|
123 |
+
List of detection indices that maps columns in `cost_matrix` to
|
124 |
+
detections in `detections` (see description above). Defaults to all
|
125 |
+
detections.
|
126 |
+
将 cost_matrix中的列映射到的检测索引列表
|
127 |
+
detections中的检测(请参见上面的说明)。 默认为全部检测。
|
128 |
+
|
129 |
+
Returns
|
130 |
+
-------
|
131 |
+
(List[(int, int)], List[int], List[int])
|
132 |
+
Returns a tuple with the following three entries:
|
133 |
+
* A list of matched track and detection indices.
|
134 |
+
* A list of unmatched track indices.
|
135 |
+
* A list of unmatched detection indices.
|
136 |
+
|
137 |
+
返回包含以下三个条目的元组:
|
138 |
+
|
139 |
+
匹配的跟踪和检测的索引列表,
|
140 |
+
不匹配的轨迹索引的列表,
|
141 |
+
未匹配的检测索引的列表。
|
142 |
+
|
143 |
+
"""
|
144 |
+
|
145 |
+
# 分配track_indices和detection_indices两个列表
|
146 |
+
if track_indices is None:
|
147 |
+
track_indices = list(range(len(tracks)))
|
148 |
+
if detection_indices is None:
|
149 |
+
detection_indices = list(range(len(detections)))
|
150 |
+
|
151 |
+
# 初始化匹配集matches M ← ∅
|
152 |
+
# 未匹配检测集unmatched_detections U ← D
|
153 |
+
unmatched_detections = detection_indices
|
154 |
+
matches = []
|
155 |
+
# 由小到大依次对每个level的tracks做匹配
|
156 |
+
for level in range(cascade_depth):
|
157 |
+
# 如果没有detections,退出循环
|
158 |
+
if len(unmatched_detections) == 0: # No detections left
|
159 |
+
break
|
160 |
+
|
161 |
+
# 当前level的所有tracks索引
|
162 |
+
# 步骤6:Select tracks by age
|
163 |
+
track_indices_l = [
|
164 |
+
k for k in track_indices
|
165 |
+
if tracks[k].time_since_update == 1 + level
|
166 |
+
]
|
167 |
+
# 如果当前level没有track,继续
|
168 |
+
if len(track_indices_l) == 0: # Nothing to match at this level
|
169 |
+
continue
|
170 |
+
|
171 |
+
# 步骤7:调用min_cost_matching函数进行匹配
|
172 |
+
matches_l, _, unmatched_detections = \
|
173 |
+
min_cost_matching(
|
174 |
+
distance_metric, max_distance, tracks, detections,
|
175 |
+
track_indices_l, unmatched_detections)
|
176 |
+
matches += matches_l # 步骤8
|
177 |
+
unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches)) # 步骤9
|
178 |
+
return matches, unmatched_tracks, unmatched_detections
|
179 |
+
|
180 |
+
'''
|
181 |
+
门控成本矩阵:通过计算卡尔曼滤波的状态分布和测量值之间的距离对成本矩阵进行限制,
|
182 |
+
成本矩阵中的距离是track和detection之间的外观相似度。
|
183 |
+
如果一个轨迹要去匹配两个外观特征非常相似的 detection,很容易出错;
|
184 |
+
分别让两个detection计算与这个轨迹的马氏距离,并使用一个阈值gating_threshold进行限制,
|
185 |
+
就可以将马氏距离较远的那个detection区分开,从而减少错误的匹配。
|
186 |
+
'''
|
187 |
+
def gate_cost_matrix(
|
188 |
+
kf, cost_matrix, tracks, detections, track_indices, detection_indices,
|
189 |
+
gated_cost=INFTY_COST, only_position=False):
|
190 |
+
"""Invalidate infeasible entries in cost matrix based on the state
|
191 |
+
distributions obtained by Kalman filtering.
|
192 |
+
|
193 |
+
Parameters
|
194 |
+
----------
|
195 |
+
kf : The Kalman filter.
|
196 |
+
cost_matrix : ndarray
|
197 |
+
The NxM dimensional cost matrix, where N is the number of track indices
|
198 |
+
and M is the number of detection indices, such that entry (i, j) is the
|
199 |
+
association cost between `tracks[track_indices[i]]` and
|
200 |
+
`detections[detection_indices[j]]`.
|
201 |
+
tracks : List[track.Track]
|
202 |
+
A list of predicted tracks at the current time step.
|
203 |
+
detections : List[detection.Detection]
|
204 |
+
A list of detections at the current time step.
|
205 |
+
track_indices : List[int]
|
206 |
+
List of track indices that maps rows in `cost_matrix` to tracks in
|
207 |
+
`tracks` (see description above).
|
208 |
+
detection_indices : List[int]
|
209 |
+
List of detection indices that maps columns in `cost_matrix` to
|
210 |
+
detections in `detections` (see description above).
|
211 |
+
gated_cost : Optional[float]
|
212 |
+
Entries in the cost matrix corresponding to infeasible associations are
|
213 |
+
set this value. Defaults to a very large value.
|
214 |
+
代价矩阵中与不可行关联相对应的条目设置此值。 默认为一个很大的值。
|
215 |
+
only_position : Optional[bool]
|
216 |
+
If True, only the x, y position of the state distribution is considered
|
217 |
+
during gating. Defaults to False.
|
218 |
+
如果为True,则在门控期间仅考虑状态分布的x,y位置。默认为False。
|
219 |
+
|
220 |
+
Returns
|
221 |
+
-------
|
222 |
+
ndarray
|
223 |
+
Returns the modified cost matrix.
|
224 |
+
|
225 |
+
"""
|
226 |
+
# 根据通过卡尔曼滤波获得的状态分布,使成本矩阵中的不可行条目无效。
|
227 |
+
gating_dim = 2 if only_position else 4 # 测量空间维度
|
228 |
+
# 马氏距离通过测算检测与平均轨迹位置的距离超过多少标准差来考虑状态估计的不确定性。
|
229 |
+
# 通过从逆chi^2分布计算95%置信区间的阈值,排除可能性小的关联。
|
230 |
+
# 四维测量空间对应的马氏阈值为9.4877
|
231 |
+
gating_threshold = kalman_filter.chi2inv95[gating_dim]
|
232 |
+
measurements = np.asarray(
|
233 |
+
[detections[i].to_xyah() for i in detection_indices])
|
234 |
+
for row, track_idx in enumerate(track_indices):
|
235 |
+
track = tracks[track_idx]
|
236 |
+
#KalmanFilter.gating_distance 计算状态分布和测量之间的选通距离
|
237 |
+
gating_distance = kf.gating_distance(
|
238 |
+
track.mean, track.covariance, measurements, only_position)
|
239 |
+
cost_matrix[row, gating_distance > gating_threshold] = gated_cost
|
240 |
+
return cost_matrix
|
deep_sort/deep_sort/sort/nn_matching.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# vim: expandtab:ts=4:sw=4
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
def _pdist(a, b):
|
6 |
+
"""Compute pair-wise squared distance between points in `a` and `b`.
|
7 |
+
|
8 |
+
Parameters
|
9 |
+
----------
|
10 |
+
a : array_like
|
11 |
+
An NxM matrix of N samples of dimensionality M.
|
12 |
+
b : array_like
|
13 |
+
An LxM matrix of L samples of dimensionality M.
|
14 |
+
|
15 |
+
Returns
|
16 |
+
-------
|
17 |
+
ndarray
|
18 |
+
Returns a matrix of size len(a), len(b) such that element (i, j)
|
19 |
+
contains the squared distance between `a[i]` and `b[j]`.
|
20 |
+
|
21 |
+
|
22 |
+
用于计算成对点之间的平方距离
|
23 |
+
a :NxM 矩阵,代表 N 个样本,每个样本 M 个数值
|
24 |
+
b :LxM 矩阵,代表 L 个样本,每个样本有 M 个数值
|
25 |
+
返回的是 NxL 的矩阵,比如 dist[i][j] 代表 a[i] 和 b[j] 之间的平方和距离
|
26 |
+
参考:https://blog.csdn.net/frankzd/article/details/80251042
|
27 |
+
|
28 |
+
"""
|
29 |
+
a, b = np.asarray(a), np.asarray(b)
|
30 |
+
if len(a) == 0 or len(b) == 0:
|
31 |
+
return np.zeros((len(a), len(b)))
|
32 |
+
a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1)
|
33 |
+
r2 = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :]
|
34 |
+
r2 = np.clip(r2, 0., float(np.inf))
|
35 |
+
return r2
|
36 |
+
|
37 |
+
|
38 |
+
def _cosine_distance(a, b, data_is_normalized=False):
|
39 |
+
"""Compute pair-wise cosine distance between points in `a` and `b`.
|
40 |
+
|
41 |
+
Parameters
|
42 |
+
----------
|
43 |
+
a : array_like
|
44 |
+
An NxM matrix of N samples of dimensionality M.
|
45 |
+
b : array_like
|
46 |
+
An LxM matrix of L samples of dimensionality M.
|
47 |
+
data_is_normalized : Optional[bool]
|
48 |
+
If True, assumes rows in a and b are unit length vectors.
|
49 |
+
Otherwise, a and b are explicitly normalized to lenght 1.
|
50 |
+
|
51 |
+
Returns
|
52 |
+
-------
|
53 |
+
ndarray
|
54 |
+
Returns a matrix of size len(a), len(b) such that eleement (i, j)
|
55 |
+
contains the squared distance between `a[i]` and `b[j]`.
|
56 |
+
|
57 |
+
用于计算成对点之间的余弦距离
|
58 |
+
a :NxM 矩阵,代表 N 个样本,每个样本 M 个数值
|
59 |
+
b :LxM 矩阵,代表 L 个样本,每个样本有 M 个数值
|
60 |
+
返回的是 NxL 的矩阵,比如 c[i][j] 代表 a[i] 和 b[j] 之间的余弦距离
|
61 |
+
参考:
|
62 |
+
https://blog.csdn.net/u013749540/article/details/51813922
|
63 |
+
|
64 |
+
|
65 |
+
"""
|
66 |
+
if not data_is_normalized:
|
67 |
+
# np.linalg.norm 求向量的范式,默认是 L2 范式
|
68 |
+
a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True)
|
69 |
+
b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True)
|
70 |
+
return 1. - np.dot(a, b.T) # 余弦距离 = 1 - 余弦相似度
|
71 |
+
|
72 |
+
|
73 |
+
def _nn_euclidean_distance(x, y):
|
74 |
+
""" Helper function for nearest neighbor distance metric (Euclidean).
|
75 |
+
|
76 |
+
Parameters
|
77 |
+
----------
|
78 |
+
x : ndarray
|
79 |
+
A matrix of N row-vectors (sample points).
|
80 |
+
y : ndarray
|
81 |
+
A matrix of M row-vectors (query points).
|
82 |
+
|
83 |
+
Returns
|
84 |
+
-------
|
85 |
+
ndarray
|
86 |
+
A vector of length M that contains for each entry in `y` the
|
87 |
+
smallest Euclidean distance to a sample in `x`.
|
88 |
+
|
89 |
+
"""
|
90 |
+
distances = _pdist(x, y)
|
91 |
+
return np.maximum(0.0, distances.min(axis=0))
|
92 |
+
|
93 |
+
|
94 |
+
def _nn_cosine_distance(x, y):
|
95 |
+
""" Helper function for nearest neighbor distance metric (cosine).
|
96 |
+
|
97 |
+
Parameters
|
98 |
+
----------
|
99 |
+
x : ndarray
|
100 |
+
A matrix of N row-vectors (sample points).
|
101 |
+
y : ndarray
|
102 |
+
A matrix of M row-vectors (query points).
|
103 |
+
|
104 |
+
Returns
|
105 |
+
-------
|
106 |
+
ndarray
|
107 |
+
A vector of length M that contains for each entry in `y` the
|
108 |
+
smallest cosine distance to a sample in `x`.
|
109 |
+
|
110 |
+
"""
|
111 |
+
distances = _cosine_distance(x, y)
|
112 |
+
return distances.min(axis=0)
|
113 |
+
|
114 |
+
|
115 |
+
class NearestNeighborDistanceMetric(object):
|
116 |
+
"""
|
117 |
+
A nearest neighbor distance metric that, for each target, returns
|
118 |
+
the closest distance to any sample that has been observed so far.
|
119 |
+
|
120 |
+
对于每个目标,返回最近邻居的距离度量, 即与到目前为止已观察到的任何样本的最接近距离。
|
121 |
+
|
122 |
+
Parameters
|
123 |
+
----------
|
124 |
+
metric : str
|
125 |
+
Either "euclidean" or "cosine".
|
126 |
+
matching_threshold: float
|
127 |
+
The matching threshold. Samples with larger distance are considered an
|
128 |
+
invalid match.
|
129 |
+
匹配阈值。 距离较大的样本对被认为是无效的匹配。
|
130 |
+
budget : Optional[int]
|
131 |
+
If not None, fix samples per class to at most this number. Removes
|
132 |
+
the oldest samples when the budget is reached.
|
133 |
+
如果不是None,则将每个类别的样本最多固定为该数字。
|
134 |
+
删除达到budget时最古老的样本。
|
135 |
+
|
136 |
+
Attributes
|
137 |
+
----------
|
138 |
+
samples : Dict[int -> List[ndarray]]
|
139 |
+
A dictionary that maps from target identities to the list of samples
|
140 |
+
that have been observed so far.
|
141 |
+
一个从目标ID映射到到目前为止已经观察到的样本列表的字典
|
142 |
+
|
143 |
+
"""
|
144 |
+
|
145 |
+
def __init__(self, metric, matching_threshold, budget=None):
|
146 |
+
|
147 |
+
|
148 |
+
if metric == "euclidean":
|
149 |
+
self._metric = _nn_euclidean_distance # 欧式距离
|
150 |
+
elif metric == "cosine":
|
151 |
+
self._metric = _nn_cosine_distance # 余弦距离
|
152 |
+
else:
|
153 |
+
raise ValueError(
|
154 |
+
"Invalid metric; must be either 'euclidean' or 'cosine'")
|
155 |
+
self.matching_threshold = matching_threshold
|
156 |
+
self.budget = budget # budge用于控制 feature 的数目
|
157 |
+
self.samples = {}
|
158 |
+
|
159 |
+
def partial_fit(self, features, targets, active_targets):
|
160 |
+
"""Update the distance metric with new data.
|
161 |
+
用新的数据更新测量距离
|
162 |
+
|
163 |
+
Parameters
|
164 |
+
----------
|
165 |
+
features : ndarray
|
166 |
+
An NxM matrix of N features of dimensionality M.
|
167 |
+
targets : ndarray
|
168 |
+
An integer array of associated target identities.
|
169 |
+
active_targets : List[int]
|
170 |
+
A list of targets that are currently present in the scene.
|
171 |
+
传入特征列表及其对应id,partial_fit构造一个活跃目标的特征字典。
|
172 |
+
|
173 |
+
"""
|
174 |
+
for feature, target in zip(features, targets):
|
175 |
+
# 对应目标下添加新的feature,更新feature集合
|
176 |
+
# samples字典 d: feature list}
|
177 |
+
self.samples.setdefault(target, []).append(feature)
|
178 |
+
if self.budget is not None:
|
179 |
+
# 只考虑budget个目标,超过直接忽略
|
180 |
+
self.samples[target] = self.samples[target][-self.budget:]
|
181 |
+
|
182 |
+
# 筛选激活的目标;samples是一个字典{id->feature list}
|
183 |
+
self.samples = {k: self.samples[k] for k in active_targets}
|
184 |
+
|
185 |
+
def distance(self, features, targets):
|
186 |
+
"""Compute distance between features and targets.
|
187 |
+
|
188 |
+
Parameters
|
189 |
+
----------
|
190 |
+
features : ndarray
|
191 |
+
An NxM matrix of N features of dimensionality M.
|
192 |
+
targets : List[int]
|
193 |
+
A list of targets to match the given `features` against.
|
194 |
+
|
195 |
+
Returns
|
196 |
+
-------
|
197 |
+
ndarray
|
198 |
+
Returns a cost matrix of shape len(targets), len(features), where
|
199 |
+
element (i, j) contains the closest squared distance between
|
200 |
+
`targets[i]` and `features[j]`.
|
201 |
+
|
202 |
+
计算features和targets之间的距离,返回一个成本矩阵(代价矩阵)
|
203 |
+
"""
|
204 |
+
cost_matrix = np.zeros((len(targets), len(features)))
|
205 |
+
for i, target in enumerate(targets):
|
206 |
+
cost_matrix[i, :] = self._metric(self.samples[target], features)
|
207 |
+
return cost_matrix
|
deep_sort/deep_sort/sort/preprocessing.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# vim: expandtab:ts=4:sw=4
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
|
5 |
+
|
6 |
+
def non_max_suppression(boxes, max_bbox_overlap, scores=None):
|
7 |
+
"""Suppress overlapping detections.
|
8 |
+
|
9 |
+
Original code from [1]_ has been adapted to include confidence score.
|
10 |
+
|
11 |
+
.. [1] http://www.pyimagesearch.com/2015/02/16/
|
12 |
+
faster-non-maximum-suppression-python/
|
13 |
+
|
14 |
+
Examples
|
15 |
+
--------
|
16 |
+
|
17 |
+
>>> boxes = [d.roi for d in detections]
|
18 |
+
>>> scores = [d.confidence for d in detections]
|
19 |
+
>>> indices = non_max_suppression(boxes, max_bbox_overlap, scores)
|
20 |
+
>>> detections = [detections[i] for i in indices]
|
21 |
+
|
22 |
+
Parameters
|
23 |
+
----------
|
24 |
+
boxes : ndarray
|
25 |
+
Array of ROIs (x, y, width, height).
|
26 |
+
max_bbox_overlap : float
|
27 |
+
ROIs that overlap more than this values are suppressed.
|
28 |
+
scores : Optional[array_like]
|
29 |
+
Detector confidence score.
|
30 |
+
|
31 |
+
Returns
|
32 |
+
-------
|
33 |
+
List[int]
|
34 |
+
Returns indices of detections that have survived non-maxima suppression.
|
35 |
+
|
36 |
+
"""
|
37 |
+
if len(boxes) == 0:
|
38 |
+
return []
|
39 |
+
|
40 |
+
boxes = boxes.astype(np.float32)
|
41 |
+
pick = []
|
42 |
+
|
43 |
+
x1 = boxes[:, 0]
|
44 |
+
y1 = boxes[:, 1]
|
45 |
+
x2 = boxes[:, 2] + boxes[:, 0]
|
46 |
+
y2 = boxes[:, 3] + boxes[:, 1]
|
47 |
+
|
48 |
+
area = (x2 - x1 + 1) * (y2 - y1 + 1)
|
49 |
+
if scores is not None:
|
50 |
+
idxs = np.argsort(scores)
|
51 |
+
else:
|
52 |
+
idxs = np.argsort(y2)
|
53 |
+
|
54 |
+
while len(idxs) > 0:
|
55 |
+
last = len(idxs) - 1
|
56 |
+
i = idxs[last]
|
57 |
+
pick.append(i)
|
58 |
+
|
59 |
+
xx1 = np.maximum(x1[i], x1[idxs[:last]])
|
60 |
+
yy1 = np.maximum(y1[i], y1[idxs[:last]])
|
61 |
+
xx2 = np.minimum(x2[i], x2[idxs[:last]])
|
62 |
+
yy2 = np.minimum(y2[i], y2[idxs[:last]])
|
63 |
+
|
64 |
+
w = np.maximum(0, xx2 - xx1 + 1)
|
65 |
+
h = np.maximum(0, yy2 - yy1 + 1)
|
66 |
+
|
67 |
+
overlap = (w * h) / area[idxs[:last]] # IOU
|
68 |
+
|
69 |
+
idxs = np.delete(
|
70 |
+
idxs, np.concatenate(
|
71 |
+
([last], np.where(overlap > max_bbox_overlap)[0])))
|
72 |
+
|
73 |
+
return pick
|
deep_sort/deep_sort/sort/track.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# vim: expandtab:ts=4:sw=4
|
2 |
+
|
3 |
+
|
4 |
+
class TrackState:
|
5 |
+
"""
|
6 |
+
Enumeration type for the single target track state. Newly created tracks are
|
7 |
+
classified as `tentative` until enough evidence has been collected. Then,
|
8 |
+
the track state is changed to `confirmed`. Tracks that are no longer alive
|
9 |
+
are classified as `deleted` to mark them for removal from the set of active
|
10 |
+
tracks.
|
11 |
+
|
12 |
+
单个目标track状态的枚举类型。
|
13 |
+
新创建的track分类为“Tentative”,直到收集到足够的证据为止。
|
14 |
+
然后,跟踪状态更改为“Confirmed”。
|
15 |
+
不再活跃的tracks被归类为“Deleted”,以将其标记为从有效集中删除。
|
16 |
+
|
17 |
+
"""
|
18 |
+
|
19 |
+
Tentative = 1
|
20 |
+
Confirmed = 2
|
21 |
+
Deleted = 3
|
22 |
+
|
23 |
+
|
24 |
+
class Track:
|
25 |
+
"""
|
26 |
+
A single target track with state space `(x, y, a, h)` and associated
|
27 |
+
velocities, where `(x, y)` is the center of the bounding box, `a` is the
|
28 |
+
aspect ratio and `h` is the height.
|
29 |
+
|
30 |
+
具有状态空间(x,y,a,h)并关联速度的单个目标轨迹(track),
|
31 |
+
其中(x,y)是边界框的中心,a是宽高比,h是高度。
|
32 |
+
|
33 |
+
Parameters
|
34 |
+
----------
|
35 |
+
mean : ndarray
|
36 |
+
Mean vector of the initial state distribution.
|
37 |
+
初始状态分布的均值向量
|
38 |
+
covariance : ndarray
|
39 |
+
Covariance matrix of the initial state distribution.
|
40 |
+
初始状态分布的协方差矩阵
|
41 |
+
track_id : int
|
42 |
+
A unique track identifier.
|
43 |
+
唯一的track标识符
|
44 |
+
n_init : int
|
45 |
+
Number of consecutive detections before the track is confirmed. The
|
46 |
+
track state is set to `Deleted` if a miss occurs within the first
|
47 |
+
`n_init` frames.
|
48 |
+
确认track之前的连续检测次数。 在第一个n_init帧中
|
49 |
+
第一个未命中的情况下将跟踪状态设置为“Deleted”
|
50 |
+
max_age : int
|
51 |
+
The maximum number of consecutive misses before the track state is
|
52 |
+
set to `Deleted`.
|
53 |
+
跟踪状态设置为Deleted之前的最大连续未命中数;代表一个track的存活期限
|
54 |
+
|
55 |
+
feature : Optional[ndarray]
|
56 |
+
Feature vector of the detection this track originates from. If not None,
|
57 |
+
this feature is added to the `features` cache.
|
58 |
+
此track所源自的检测的特征向量。 如果不是None,此feature已添加到feature缓存中。
|
59 |
+
|
60 |
+
Attributes
|
61 |
+
----------
|
62 |
+
mean : ndarray
|
63 |
+
Mean vector of the initial state distribution.
|
64 |
+
初始状态分布的均值向量
|
65 |
+
covariance : ndarray
|
66 |
+
Covariance matrix of the initial state distribution.
|
67 |
+
初始状态分布的协方差矩阵
|
68 |
+
track_id : int
|
69 |
+
A unique track identifier.
|
70 |
+
hits : int
|
71 |
+
Total number of measurement updates.
|
72 |
+
测量更新总数
|
73 |
+
age : int
|
74 |
+
Total number of frames since first occurence.
|
75 |
+
自第一次出现以来的总帧数
|
76 |
+
time_since_update : int
|
77 |
+
Total number of frames since last measurement update.
|
78 |
+
自上次测量更新以来的总帧数
|
79 |
+
state : TrackState
|
80 |
+
The current track state.
|
81 |
+
features : List[ndarray]
|
82 |
+
A cache of features. On each measurement update, the associated feature
|
83 |
+
vector is added to this list.
|
84 |
+
feature缓存。每次测量更新时,相关feature向量添加到此列表中
|
85 |
+
|
86 |
+
"""
|
87 |
+
|
88 |
+
def __init__(self, mean, covariance, track_id, n_init, max_age,
|
89 |
+
feature=None):
|
90 |
+
self.mean = mean
|
91 |
+
self.covariance = covariance
|
92 |
+
self.track_id = track_id
|
93 |
+
# hits代表匹配上了多少次,匹配次数超过n_init,设置Confirmed状态
|
94 |
+
# hits每次调用update函数的时候+1
|
95 |
+
self.hits = 1
|
96 |
+
self.age = 1 # 和time_since_update功能重复
|
97 |
+
# 每次调用predict函数的时候就会+1; 每次调用update函数的时候就会设置为0
|
98 |
+
self.time_since_update = 0
|
99 |
+
|
100 |
+
self.state = TrackState.Tentative # 初始化一个Track的时设置Tentative状态
|
101 |
+
# 每个track对应多个features, 每次更新都会将最新的feature添加到列表中
|
102 |
+
self.features = []
|
103 |
+
if feature is not None:
|
104 |
+
self.features.append(feature)
|
105 |
+
|
106 |
+
self._n_init = n_init
|
107 |
+
self._max_age = max_age
|
108 |
+
|
109 |
+
def to_tlwh(self):
|
110 |
+
"""Get current position in bounding box format `(top left x, top left y,
|
111 |
+
width, height)`.
|
112 |
+
|
113 |
+
Returns
|
114 |
+
-------
|
115 |
+
ndarray
|
116 |
+
The bounding box.
|
117 |
+
|
118 |
+
"""
|
119 |
+
ret = self.mean[:4].copy()
|
120 |
+
ret[2] *= ret[3]
|
121 |
+
ret[:2] -= ret[2:] / 2
|
122 |
+
return ret
|
123 |
+
|
124 |
+
def to_tlbr(self):
|
125 |
+
"""Get current position in bounding box format `(min x, miny, max x,
|
126 |
+
max y)`.
|
127 |
+
|
128 |
+
Returns
|
129 |
+
-------
|
130 |
+
ndarray
|
131 |
+
The bounding box.
|
132 |
+
|
133 |
+
"""
|
134 |
+
ret = self.to_tlwh()
|
135 |
+
ret[2:] = ret[:2] + ret[2:]
|
136 |
+
return ret
|
137 |
+
|
138 |
+
def predict(self, kf):
|
139 |
+
"""Propagate the state distribution to the current time step using a
|
140 |
+
Kalman filter prediction step.
|
141 |
+
使用卡尔曼滤波器预测步骤将状态分布传播到当前时间步
|
142 |
+
|
143 |
+
Parameters
|
144 |
+
----------
|
145 |
+
kf : kalman_filter.KalmanFilter
|
146 |
+
The Kalman filter.
|
147 |
+
|
148 |
+
"""
|
149 |
+
self.mean, self.covariance = kf.predict(self.mean, self.covariance)
|
150 |
+
self.age += 1
|
151 |
+
self.time_since_update += 1
|
152 |
+
|
153 |
+
def update(self, kf, detection):
|
154 |
+
"""Perform Kalman filter measurement update step and update the feature
|
155 |
+
cache.
|
156 |
+
执行卡尔曼滤波器测量更新步骤并更新feature缓存
|
157 |
+
|
158 |
+
Parameters
|
159 |
+
----------
|
160 |
+
kf : kalman_filter.KalmanFilter
|
161 |
+
The Kalman filter.
|
162 |
+
detection : Detection
|
163 |
+
The associated detection.
|
164 |
+
|
165 |
+
"""
|
166 |
+
self.mean, self.covariance = kf.update(
|
167 |
+
self.mean, self.covariance, detection.to_xyah())
|
168 |
+
self.features.append(detection.feature)
|
169 |
+
|
170 |
+
self.hits += 1
|
171 |
+
self.time_since_update = 0
|
172 |
+
# hits代表匹配上了多少次,匹配次数超过n_init,设置Confirmed状态
|
173 |
+
# 连续匹配上n_init帧的时候,转变为确定态
|
174 |
+
if self.state == TrackState.Tentative and self.hits >= self._n_init:
|
175 |
+
self.state = TrackState.Confirmed
|
176 |
+
|
177 |
+
def mark_missed(self):
|
178 |
+
"""Mark this track as missed (no association at the current time step).
|
179 |
+
"""
|
180 |
+
# 如果在处于Tentative态的情况下没有匹配上任何detection,转变为删除态。
|
181 |
+
if self.state == TrackState.Tentative:
|
182 |
+
self.state = TrackState.Deleted
|
183 |
+
elif self.time_since_update > self._max_age:
|
184 |
+
# 如果time_since_update超过max_age,设置Deleted状态
|
185 |
+
# 即失配连续达到max_age次数的时候,转变为删除态
|
186 |
+
self.state = TrackState.Deleted
|
187 |
+
|
188 |
+
def is_tentative(self):
|
189 |
+
"""Returns True if this track is tentative (unconfirmed).
|
190 |
+
"""
|
191 |
+
return self.state == TrackState.Tentative
|
192 |
+
|
193 |
+
def is_confirmed(self):
|
194 |
+
"""Returns True if this track is confirmed."""
|
195 |
+
return self.state == TrackState.Confirmed
|
196 |
+
|
197 |
+
def is_deleted(self):
|
198 |
+
"""Returns True if this track is dead and should be deleted."""
|
199 |
+
return self.state == TrackState.Deleted
|
deep_sort/deep_sort/sort/tracker.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# vim: expandtab:ts=4:sw=4
|
2 |
+
from __future__ import absolute_import
|
3 |
+
import numpy as np
|
4 |
+
from . import kalman_filter
|
5 |
+
from . import linear_assignment
|
6 |
+
from . import iou_matching
|
7 |
+
from .track import Track
|
8 |
+
|
9 |
+
|
10 |
+
class Tracker:
|
11 |
+
"""
|
12 |
+
This is the multi-target tracker.
|
13 |
+
|
14 |
+
Parameters
|
15 |
+
----------
|
16 |
+
metric : nn_matching.NearestNeighborDistanceMetric
|
17 |
+
A distance metric for measurement-to-track association.
|
18 |
+
max_age : int
|
19 |
+
Maximum number of missed misses before a track is deleted.
|
20 |
+
n_init : int
|
21 |
+
Number of consecutive detections before the track is confirmed. The
|
22 |
+
track state is set to `Deleted` if a miss occurs within the first
|
23 |
+
`n_init` frames.
|
24 |
+
|
25 |
+
Attributes
|
26 |
+
----------
|
27 |
+
metric : nn_matching.NearestNeighborDistanceMetric
|
28 |
+
The distance metric used for measurement to track association.
|
29 |
+
测量与轨迹关联的距离度量
|
30 |
+
max_age : int
|
31 |
+
Maximum number of missed misses before a track is deleted.
|
32 |
+
删除轨迹前的最大未命中数
|
33 |
+
n_init : int
|
34 |
+
Number of frames that a track remains in initialization phase.
|
35 |
+
确认轨迹前的连续检测次数。如果前n_init帧内发生未命中,则将轨迹状态设置为Deleted
|
36 |
+
kf : kalman_filter.KalmanFilter
|
37 |
+
A Kalman filter to filter target trajectories in image space.
|
38 |
+
tracks : List[Track]
|
39 |
+
The list of active tracks at the current time step.
|
40 |
+
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self, metric, max_iou_distance=0.7, max_age=70, n_init=3):
|
44 |
+
self.metric = metric
|
45 |
+
self.max_iou_distance = max_iou_distance
|
46 |
+
self.max_age = max_age
|
47 |
+
self.n_init = n_init
|
48 |
+
|
49 |
+
self.kf = kalman_filter.KalmanFilter() # 实例化卡尔曼滤波器
|
50 |
+
self.tracks = [] # 保存一个轨迹列表,用于保存一系列轨迹
|
51 |
+
self._next_id = 1 # 下一个分配的轨迹id
|
52 |
+
|
53 |
+
def predict(self):
|
54 |
+
"""Propagate track state distributions one time step forward.
|
55 |
+
将跟踪状态分布向前传播一步
|
56 |
+
|
57 |
+
This function should be called once every time step, before `update`.
|
58 |
+
"""
|
59 |
+
for track in self.tracks:
|
60 |
+
track.predict(self.kf)
|
61 |
+
|
62 |
+
def update(self, detections):
|
63 |
+
"""Perform measurement update and track management.
|
64 |
+
执行测量更新和轨迹管理
|
65 |
+
|
66 |
+
Parameters
|
67 |
+
----------
|
68 |
+
detections : List[deep_sort.detection.Detection]
|
69 |
+
A list of detections at the current time step.
|
70 |
+
|
71 |
+
"""
|
72 |
+
# Run matching cascade.
|
73 |
+
matches, unmatched_tracks, unmatched_detections = \
|
74 |
+
self._match(detections)
|
75 |
+
|
76 |
+
# Update track set.
|
77 |
+
|
78 |
+
# 1. 针对匹配上的结果
|
79 |
+
for track_idx, detection_idx in matches:
|
80 |
+
# 更新tracks中相应的detection
|
81 |
+
self.tracks[track_idx].update(
|
82 |
+
self.kf, detections[detection_idx])
|
83 |
+
|
84 |
+
# 2. 针对未匹配的track, 调用mark_missed进行标记
|
85 |
+
# track失配时,若Tantative则删除;若update时间很久也删除
|
86 |
+
for track_idx in unmatched_tracks:
|
87 |
+
self.tracks[track_idx].mark_missed()
|
88 |
+
|
89 |
+
# 3. 针对未匹配的detection, detection失配,进行初始化
|
90 |
+
for detection_idx in unmatched_detections:
|
91 |
+
self._initiate_track(detections[detection_idx])
|
92 |
+
|
93 |
+
# 得到最新的tracks列表,保存的是标记为Confirmed和Tentative的track
|
94 |
+
self.tracks = [t for t in self.tracks if not t.is_deleted()]
|
95 |
+
|
96 |
+
# Update distance metric.
|
97 |
+
active_targets = [t.track_id for t in self.tracks if t.is_confirmed()]
|
98 |
+
features, targets = [], []
|
99 |
+
for track in self.tracks:
|
100 |
+
# 获取所有Confirmed状态的track id
|
101 |
+
if not track.is_confirmed():
|
102 |
+
continue
|
103 |
+
features += track.features # 将Confirmed状态的track的features添加到features列表
|
104 |
+
# 获取每个feature对应的trackid
|
105 |
+
targets += [track.track_id for _ in track.features]
|
106 |
+
track.features = []
|
107 |
+
# 距离度量中的特征集更新
|
108 |
+
self.metric.partial_fit(
|
109 |
+
np.asarray(features), np.asarray(targets), active_targets)
|
110 |
+
|
111 |
+
def _match(self, detections):
|
112 |
+
|
113 |
+
def gated_metric(tracks, dets, track_indices, detection_indices):
|
114 |
+
features = np.array([dets[i].feature for i in detection_indices])
|
115 |
+
targets = np.array([tracks[i].track_id for i in track_indices])
|
116 |
+
|
117 |
+
# 通过最近邻(余弦距离)计算出成本矩阵(代价矩阵)
|
118 |
+
cost_matrix = self.metric.distance(features, targets)
|
119 |
+
# 计算门控后的成本矩阵(代价矩阵)
|
120 |
+
cost_matrix = linear_assignment.gate_cost_matrix(
|
121 |
+
self.kf, cost_matrix, tracks, dets, track_indices,
|
122 |
+
detection_indices)
|
123 |
+
|
124 |
+
return cost_matrix
|
125 |
+
|
126 |
+
# Split track set into confirmed and unconfirmed tracks.
|
127 |
+
# 区分开confirmed tracks和unconfirmed tracks
|
128 |
+
confirmed_tracks = [
|
129 |
+
i for i, t in enumerate(self.tracks) if t.is_confirmed()]
|
130 |
+
unconfirmed_tracks = [
|
131 |
+
i for i, t in enumerate(self.tracks) if not t.is_confirmed()]
|
132 |
+
|
133 |
+
# Associate confirmed tracks using appearance features.
|
134 |
+
# 对确定态的轨迹进行级联匹配,得到匹配的tracks、不匹配的tracks、不匹配的detections
|
135 |
+
# matching_cascade 根据特征将检测框匹配到确认的轨迹。
|
136 |
+
# 传入门控后的成本矩阵
|
137 |
+
matches_a, unmatched_tracks_a, unmatched_detections = \
|
138 |
+
linear_assignment.matching_cascade(
|
139 |
+
gated_metric, self.metric.matching_threshold, self.max_age,
|
140 |
+
self.tracks, detections, confirmed_tracks)
|
141 |
+
|
142 |
+
# Associate remaining tracks together with unconfirmed tracks using IOU.
|
143 |
+
# 将未确定态的轨迹和刚刚没有匹配上的轨迹组合为 iou_track_candidates
|
144 |
+
# 并进行基于IoU的匹配
|
145 |
+
iou_track_candidates = unconfirmed_tracks + [
|
146 |
+
k for k in unmatched_tracks_a if
|
147 |
+
self.tracks[k].time_since_update == 1] # 刚刚没有匹配上的轨迹
|
148 |
+
unmatched_tracks_a = [
|
149 |
+
k for k in unmatched_tracks_a if
|
150 |
+
self.tracks[k].time_since_update != 1] # 并非刚刚没有匹配上的轨迹
|
151 |
+
# 对级联匹配中还没有匹配成功的目标再进行IoU匹配
|
152 |
+
# min_cost_matching 使用匈牙利算法解决线性分配问题。
|
153 |
+
# 传入 iou_cost,尝试关联剩余的轨迹与未确认的轨迹。
|
154 |
+
matches_b, unmatched_tracks_b, unmatched_detections = \
|
155 |
+
linear_assignment.min_cost_matching(
|
156 |
+
iou_matching.iou_cost, self.max_iou_distance, self.tracks,
|
157 |
+
detections, iou_track_candidates, unmatched_detections)
|
158 |
+
|
159 |
+
matches = matches_a + matches_b # 组合两部分匹配
|
160 |
+
unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
|
161 |
+
return matches, unmatched_tracks, unmatched_detections
|
162 |
+
|
163 |
+
def _initiate_track(self, detection):
|
164 |
+
mean, covariance = self.kf.initiate(detection.to_xyah())
|
165 |
+
self.tracks.append(Track(
|
166 |
+
mean, covariance, self._next_id, self.n_init, self.max_age,
|
167 |
+
detection.feature))
|
168 |
+
self._next_id += 1
|
deep_sort/utils/__init__.py
ADDED
File without changes
|
deep_sort/utils/asserts.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os import environ
|
2 |
+
|
3 |
+
|
4 |
+
def assert_in(file, files_to_check):
|
5 |
+
if file not in files_to_check:
|
6 |
+
raise AssertionError("{} does not exist in the list".format(str(file)))
|
7 |
+
return True
|
8 |
+
|
9 |
+
|
10 |
+
def assert_in_env(check_list: list):
|
11 |
+
for item in check_list:
|
12 |
+
assert_in(item, environ.keys())
|
13 |
+
return True
|
deep_sort/utils/draw.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
|
4 |
+
palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)
|
5 |
+
|
6 |
+
|
7 |
+
def compute_color_for_labels(label):
|
8 |
+
"""
|
9 |
+
Simple function that adds fixed color depending on the class
|
10 |
+
"""
|
11 |
+
color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
|
12 |
+
return tuple(color)
|
13 |
+
|
14 |
+
|
15 |
+
def draw_boxes(img, bbox, identities=None, offset=(0,0)):
|
16 |
+
for i,box in enumerate(bbox):
|
17 |
+
x1,y1,x2,y2 = [int(i) for i in box]
|
18 |
+
x1 += offset[0]
|
19 |
+
x2 += offset[0]
|
20 |
+
y1 += offset[1]
|
21 |
+
y2 += offset[1]
|
22 |
+
# box text and bar
|
23 |
+
id = int(identities[i]) if identities is not None else 0
|
24 |
+
color = compute_color_for_labels(id)
|
25 |
+
label = '{}{:d}'.format("", id)
|
26 |
+
t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2 , 2)[0]
|
27 |
+
cv2.rectangle(img,(x1, y1),(x2,y2),color,3)
|
28 |
+
cv2.rectangle(img,(x1, y1),(x1+t_size[0]+3,y1+t_size[1]+4), color,-1)
|
29 |
+
cv2.putText(img,label,(x1,y1+t_size[1]+4), cv2.FONT_HERSHEY_PLAIN, 2, [255,255,255], 2)
|
30 |
+
return img
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
if __name__ == '__main__':
|
35 |
+
for i in range(82):
|
36 |
+
print(compute_color_for_labels(i))
|
deep_sort/utils/evaluation.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import copy
|
4 |
+
import motmetrics as mm
|
5 |
+
mm.lap.default_solver = 'lap'
|
6 |
+
from utils.io import read_results, unzip_objs
|
7 |
+
|
8 |
+
|
9 |
+
class Evaluator(object):
|
10 |
+
|
11 |
+
def __init__(self, data_root, seq_name, data_type):
|
12 |
+
self.data_root = data_root
|
13 |
+
self.seq_name = seq_name
|
14 |
+
self.data_type = data_type
|
15 |
+
|
16 |
+
self.load_annotations()
|
17 |
+
self.reset_accumulator()
|
18 |
+
|
19 |
+
def load_annotations(self):
|
20 |
+
assert self.data_type == 'mot'
|
21 |
+
|
22 |
+
gt_filename = os.path.join(self.data_root, self.seq_name, 'gt', 'gt.txt')
|
23 |
+
self.gt_frame_dict = read_results(gt_filename, self.data_type, is_gt=True)
|
24 |
+
self.gt_ignore_frame_dict = read_results(gt_filename, self.data_type, is_ignore=True)
|
25 |
+
|
26 |
+
def reset_accumulator(self):
|
27 |
+
self.acc = mm.MOTAccumulator(auto_id=True)
|
28 |
+
|
29 |
+
def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False):
|
30 |
+
# results
|
31 |
+
trk_tlwhs = np.copy(trk_tlwhs)
|
32 |
+
trk_ids = np.copy(trk_ids)
|
33 |
+
|
34 |
+
# gts
|
35 |
+
gt_objs = self.gt_frame_dict.get(frame_id, [])
|
36 |
+
gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2]
|
37 |
+
|
38 |
+
# ignore boxes
|
39 |
+
ignore_objs = self.gt_ignore_frame_dict.get(frame_id, [])
|
40 |
+
ignore_tlwhs = unzip_objs(ignore_objs)[0]
|
41 |
+
|
42 |
+
|
43 |
+
# remove ignored results
|
44 |
+
keep = np.ones(len(trk_tlwhs), dtype=bool)
|
45 |
+
iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5)
|
46 |
+
if len(iou_distance) > 0:
|
47 |
+
match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
|
48 |
+
match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js])
|
49 |
+
match_ious = iou_distance[match_is, match_js]
|
50 |
+
|
51 |
+
match_js = np.asarray(match_js, dtype=int)
|
52 |
+
match_js = match_js[np.logical_not(np.isnan(match_ious))]
|
53 |
+
keep[match_js] = False
|
54 |
+
trk_tlwhs = trk_tlwhs[keep]
|
55 |
+
trk_ids = trk_ids[keep]
|
56 |
+
|
57 |
+
# get distance matrix
|
58 |
+
iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5)
|
59 |
+
|
60 |
+
# acc
|
61 |
+
self.acc.update(gt_ids, trk_ids, iou_distance)
|
62 |
+
|
63 |
+
if rtn_events and iou_distance.size > 0 and hasattr(self.acc, 'last_mot_events'):
|
64 |
+
events = self.acc.last_mot_events # only supported by https://github.com/longcw/py-motmetrics
|
65 |
+
else:
|
66 |
+
events = None
|
67 |
+
return events
|
68 |
+
|
69 |
+
def eval_file(self, filename):
|
70 |
+
self.reset_accumulator()
|
71 |
+
|
72 |
+
result_frame_dict = read_results(filename, self.data_type, is_gt=False)
|
73 |
+
frames = sorted(list(set(self.gt_frame_dict.keys()) | set(result_frame_dict.keys())))
|
74 |
+
for frame_id in frames:
|
75 |
+
trk_objs = result_frame_dict.get(frame_id, [])
|
76 |
+
trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2]
|
77 |
+
self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False)
|
78 |
+
|
79 |
+
return self.acc
|
80 |
+
|
81 |
+
@staticmethod
|
82 |
+
def get_summary(accs, names, metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1', 'precision', 'recall')):
|
83 |
+
names = copy.deepcopy(names)
|
84 |
+
if metrics is None:
|
85 |
+
metrics = mm.metrics.motchallenge_metrics
|
86 |
+
metrics = copy.deepcopy(metrics)
|
87 |
+
|
88 |
+
mh = mm.metrics.create()
|
89 |
+
summary = mh.compute_many(
|
90 |
+
accs,
|
91 |
+
metrics=metrics,
|
92 |
+
names=names,
|
93 |
+
generate_overall=True
|
94 |
+
)
|
95 |
+
|
96 |
+
return summary
|
97 |
+
|
98 |
+
@staticmethod
|
99 |
+
def save_summary(summary, filename):
|
100 |
+
import pandas as pd
|
101 |
+
writer = pd.ExcelWriter(filename)
|
102 |
+
summary.to_excel(writer)
|
103 |
+
writer.save()
|
deep_sort/utils/io.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Dict
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
# from utils.log import get_logger
|
6 |
+
|
7 |
+
|
8 |
+
def write_results(filename, results, data_type):
|
9 |
+
if data_type == 'mot':
|
10 |
+
save_format = '{frame},{id},{x1},{y1},{w},{h},-1,-1,-1,-1\n'
|
11 |
+
elif data_type == 'kitti':
|
12 |
+
save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
|
13 |
+
else:
|
14 |
+
raise ValueError(data_type)
|
15 |
+
|
16 |
+
with open(filename, 'w') as f:
|
17 |
+
for frame_id, tlwhs, track_ids in results:
|
18 |
+
if data_type == 'kitti':
|
19 |
+
frame_id -= 1
|
20 |
+
for tlwh, track_id in zip(tlwhs, track_ids):
|
21 |
+
if track_id < 0:
|
22 |
+
continue
|
23 |
+
x1, y1, w, h = tlwh
|
24 |
+
x2, y2 = x1 + w, y1 + h
|
25 |
+
line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h)
|
26 |
+
f.write(line)
|
27 |
+
|
28 |
+
|
29 |
+
# def write_results(filename, results_dict: Dict, data_type: str):
|
30 |
+
# if not filename:
|
31 |
+
# return
|
32 |
+
# path = os.path.dirname(filename)
|
33 |
+
# if not os.path.exists(path):
|
34 |
+
# os.makedirs(path)
|
35 |
+
|
36 |
+
# if data_type in ('mot', 'mcmot', 'lab'):
|
37 |
+
# save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
|
38 |
+
# elif data_type == 'kitti':
|
39 |
+
# save_format = '{frame} {id} pedestrian -1 -1 -10 {x1} {y1} {x2} {y2} -1 -1 -1 -1000 -1000 -1000 -10 {score}\n'
|
40 |
+
# else:
|
41 |
+
# raise ValueError(data_type)
|
42 |
+
|
43 |
+
# with open(filename, 'w') as f:
|
44 |
+
# for frame_id, frame_data in results_dict.items():
|
45 |
+
# if data_type == 'kitti':
|
46 |
+
# frame_id -= 1
|
47 |
+
# for tlwh, track_id in frame_data:
|
48 |
+
# if track_id < 0:
|
49 |
+
# continue
|
50 |
+
# x1, y1, w, h = tlwh
|
51 |
+
# x2, y2 = x1 + w, y1 + h
|
52 |
+
# line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h, score=1.0)
|
53 |
+
# f.write(line)
|
54 |
+
# logger.info('Save results to {}'.format(filename))
|
55 |
+
|
56 |
+
|
57 |
+
def read_results(filename, data_type: str, is_gt=False, is_ignore=False):
|
58 |
+
if data_type in ('mot', 'lab'):
|
59 |
+
read_fun = read_mot_results
|
60 |
+
else:
|
61 |
+
raise ValueError('Unknown data type: {}'.format(data_type))
|
62 |
+
|
63 |
+
return read_fun(filename, is_gt, is_ignore)
|
64 |
+
|
65 |
+
|
66 |
+
"""
|
67 |
+
labels={'ped', ... % 1
|
68 |
+
'person_on_vhcl', ... % 2
|
69 |
+
'car', ... % 3
|
70 |
+
'bicycle', ... % 4
|
71 |
+
'mbike', ... % 5
|
72 |
+
'non_mot_vhcl', ... % 6
|
73 |
+
'static_person', ... % 7
|
74 |
+
'distractor', ... % 8
|
75 |
+
'occluder', ... % 9
|
76 |
+
'occluder_on_grnd', ... %10
|
77 |
+
'occluder_full', ... % 11
|
78 |
+
'reflection', ... % 12
|
79 |
+
'crowd' ... % 13
|
80 |
+
};
|
81 |
+
"""
|
82 |
+
|
83 |
+
|
84 |
+
def read_mot_results(filename, is_gt, is_ignore):
|
85 |
+
valid_labels = {1}
|
86 |
+
ignore_labels = {2, 7, 8, 12}
|
87 |
+
results_dict = dict()
|
88 |
+
if os.path.isfile(filename):
|
89 |
+
with open(filename, 'r') as f:
|
90 |
+
for line in f.readlines():
|
91 |
+
linelist = line.split(',')
|
92 |
+
if len(linelist) < 7:
|
93 |
+
continue
|
94 |
+
fid = int(linelist[0])
|
95 |
+
if fid < 1:
|
96 |
+
continue
|
97 |
+
results_dict.setdefault(fid, list())
|
98 |
+
|
99 |
+
if is_gt:
|
100 |
+
if 'MOT16-' in filename or 'MOT17-' in filename:
|
101 |
+
label = int(float(linelist[7]))
|
102 |
+
mark = int(float(linelist[6]))
|
103 |
+
if mark == 0 or label not in valid_labels:
|
104 |
+
continue
|
105 |
+
score = 1
|
106 |
+
elif is_ignore:
|
107 |
+
if 'MOT16-' in filename or 'MOT17-' in filename:
|
108 |
+
label = int(float(linelist[7]))
|
109 |
+
vis_ratio = float(linelist[8])
|
110 |
+
if label not in ignore_labels and vis_ratio >= 0:
|
111 |
+
continue
|
112 |
+
else:
|
113 |
+
continue
|
114 |
+
score = 1
|
115 |
+
else:
|
116 |
+
score = float(linelist[6])
|
117 |
+
|
118 |
+
tlwh = tuple(map(float, linelist[2:6]))
|
119 |
+
target_id = int(linelist[1])
|
120 |
+
|
121 |
+
results_dict[fid].append((tlwh, target_id, score))
|
122 |
+
|
123 |
+
return results_dict
|
124 |
+
|
125 |
+
|
126 |
+
def unzip_objs(objs):
|
127 |
+
if len(objs) > 0:
|
128 |
+
tlwhs, ids, scores = zip(*objs)
|
129 |
+
else:
|
130 |
+
tlwhs, ids, scores = [], [], []
|
131 |
+
tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
|
132 |
+
|
133 |
+
return tlwhs, ids, scores
|
deep_sort/utils/json_logger.py
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
References:
|
3 |
+
https://medium.com/analytics-vidhya/creating-a-custom-logging-mechanism-for-real-time-object-detection-using-tdd-4ca2cfcd0a2f
|
4 |
+
"""
|
5 |
+
import json
|
6 |
+
from os import makedirs
|
7 |
+
from os.path import exists, join
|
8 |
+
from datetime import datetime
|
9 |
+
|
10 |
+
|
11 |
+
class JsonMeta(object):
|
12 |
+
HOURS = 3
|
13 |
+
MINUTES = 59
|
14 |
+
SECONDS = 59
|
15 |
+
PATH_TO_SAVE = 'LOGS'
|
16 |
+
DEFAULT_FILE_NAME = 'remaining'
|
17 |
+
|
18 |
+
|
19 |
+
class BaseJsonLogger(object):
|
20 |
+
"""
|
21 |
+
This is the base class that returns __dict__ of its own
|
22 |
+
it also returns the dicts of objects in the attributes that are list instances
|
23 |
+
|
24 |
+
"""
|
25 |
+
|
26 |
+
def dic(self):
|
27 |
+
# returns dicts of objects
|
28 |
+
out = {}
|
29 |
+
for k, v in self.__dict__.items():
|
30 |
+
if hasattr(v, 'dic'):
|
31 |
+
out[k] = v.dic()
|
32 |
+
elif isinstance(v, list):
|
33 |
+
out[k] = self.list(v)
|
34 |
+
else:
|
35 |
+
out[k] = v
|
36 |
+
return out
|
37 |
+
|
38 |
+
@staticmethod
|
39 |
+
def list(values):
|
40 |
+
# applies the dic method on items in the list
|
41 |
+
return [v.dic() if hasattr(v, 'dic') else v for v in values]
|
42 |
+
|
43 |
+
|
44 |
+
class Label(BaseJsonLogger):
|
45 |
+
"""
|
46 |
+
For each bounding box there are various categories with confidences. Label class keeps track of that information.
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(self, category: str, confidence: float):
|
50 |
+
self.category = category
|
51 |
+
self.confidence = confidence
|
52 |
+
|
53 |
+
|
54 |
+
class Bbox(BaseJsonLogger):
|
55 |
+
"""
|
56 |
+
This module stores the information for each frame and use them in JsonParser
|
57 |
+
Attributes:
|
58 |
+
labels (list): List of label module.
|
59 |
+
top (int):
|
60 |
+
left (int):
|
61 |
+
width (int):
|
62 |
+
height (int):
|
63 |
+
|
64 |
+
Args:
|
65 |
+
bbox_id (float):
|
66 |
+
top (int):
|
67 |
+
left (int):
|
68 |
+
width (int):
|
69 |
+
height (int):
|
70 |
+
|
71 |
+
References:
|
72 |
+
Check Label module for better understanding.
|
73 |
+
|
74 |
+
|
75 |
+
"""
|
76 |
+
|
77 |
+
def __init__(self, bbox_id, top, left, width, height):
|
78 |
+
self.labels = []
|
79 |
+
self.bbox_id = bbox_id
|
80 |
+
self.top = top
|
81 |
+
self.left = left
|
82 |
+
self.width = width
|
83 |
+
self.height = height
|
84 |
+
|
85 |
+
def add_label(self, category, confidence):
|
86 |
+
# adds category and confidence only if top_k is not exceeded.
|
87 |
+
self.labels.append(Label(category, confidence))
|
88 |
+
|
89 |
+
def labels_full(self, value):
|
90 |
+
return len(self.labels) == value
|
91 |
+
|
92 |
+
|
93 |
+
class Frame(BaseJsonLogger):
|
94 |
+
"""
|
95 |
+
This module stores the information for each frame and use them in JsonParser
|
96 |
+
Attributes:
|
97 |
+
timestamp (float): The elapsed time of captured frame
|
98 |
+
frame_id (int): The frame number of the captured video
|
99 |
+
bboxes (list of Bbox objects): Stores the list of bbox objects.
|
100 |
+
|
101 |
+
References:
|
102 |
+
Check Bbox class for better information
|
103 |
+
|
104 |
+
Args:
|
105 |
+
timestamp (float):
|
106 |
+
frame_id (int):
|
107 |
+
|
108 |
+
"""
|
109 |
+
|
110 |
+
def __init__(self, frame_id: int, timestamp: float = None):
|
111 |
+
self.frame_id = frame_id
|
112 |
+
self.timestamp = timestamp
|
113 |
+
self.bboxes = []
|
114 |
+
|
115 |
+
def add_bbox(self, bbox_id: int, top: int, left: int, width: int, height: int):
|
116 |
+
bboxes_ids = [bbox.bbox_id for bbox in self.bboxes]
|
117 |
+
if bbox_id not in bboxes_ids:
|
118 |
+
self.bboxes.append(Bbox(bbox_id, top, left, width, height))
|
119 |
+
else:
|
120 |
+
raise ValueError("Frame with id: {} already has a Bbox with id: {}".format(self.frame_id, bbox_id))
|
121 |
+
|
122 |
+
def add_label_to_bbox(self, bbox_id: int, category: str, confidence: float):
|
123 |
+
bboxes = {bbox.id: bbox for bbox in self.bboxes}
|
124 |
+
if bbox_id in bboxes.keys():
|
125 |
+
res = bboxes.get(bbox_id)
|
126 |
+
res.add_label(category, confidence)
|
127 |
+
else:
|
128 |
+
raise ValueError('the bbox with id: {} does not exists!'.format(bbox_id))
|
129 |
+
|
130 |
+
|
131 |
+
class BboxToJsonLogger(BaseJsonLogger):
|
132 |
+
"""
|
133 |
+
ُ This module is designed to automate the task of logging jsons. An example json is used
|
134 |
+
to show the contents of json file shortly
|
135 |
+
Example:
|
136 |
+
{
|
137 |
+
"video_details": {
|
138 |
+
"frame_width": 1920,
|
139 |
+
"frame_height": 1080,
|
140 |
+
"frame_rate": 20,
|
141 |
+
"video_name": "/home/gpu/codes/MSD/pedestrian_2/project/public/camera1.avi"
|
142 |
+
},
|
143 |
+
"frames": [
|
144 |
+
{
|
145 |
+
"frame_id": 329,
|
146 |
+
"timestamp": 3365.1254
|
147 |
+
"bboxes": [
|
148 |
+
{
|
149 |
+
"labels": [
|
150 |
+
{
|
151 |
+
"category": "pedestrian",
|
152 |
+
"confidence": 0.9
|
153 |
+
}
|
154 |
+
],
|
155 |
+
"bbox_id": 0,
|
156 |
+
"top": 1257,
|
157 |
+
"left": 138,
|
158 |
+
"width": 68,
|
159 |
+
"height": 109
|
160 |
+
}
|
161 |
+
]
|
162 |
+
}],
|
163 |
+
|
164 |
+
Attributes:
|
165 |
+
frames (dict): It's a dictionary that maps each frame_id to json attributes.
|
166 |
+
video_details (dict): information about video file.
|
167 |
+
top_k_labels (int): shows the allowed number of labels
|
168 |
+
start_time (datetime object): we use it to automate the json output by time.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
top_k_labels (int): shows the allowed number of labels
|
172 |
+
|
173 |
+
"""
|
174 |
+
|
175 |
+
def __init__(self, top_k_labels: int = 1):
|
176 |
+
self.frames = {}
|
177 |
+
self.video_details = self.video_details = dict(frame_width=None, frame_height=None, frame_rate=None,
|
178 |
+
video_name=None)
|
179 |
+
self.top_k_labels = top_k_labels
|
180 |
+
self.start_time = datetime.now()
|
181 |
+
|
182 |
+
def set_top_k(self, value):
|
183 |
+
self.top_k_labels = value
|
184 |
+
|
185 |
+
def frame_exists(self, frame_id: int) -> bool:
|
186 |
+
"""
|
187 |
+
Args:
|
188 |
+
frame_id (int):
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
bool: true if frame_id is recognized
|
192 |
+
"""
|
193 |
+
return frame_id in self.frames.keys()
|
194 |
+
|
195 |
+
def add_frame(self, frame_id: int, timestamp: float = None) -> None:
|
196 |
+
"""
|
197 |
+
Args:
|
198 |
+
frame_id (int):
|
199 |
+
timestamp (float): opencv captured frame time property
|
200 |
+
|
201 |
+
Raises:
|
202 |
+
ValueError: if frame_id would not exist in class frames attribute
|
203 |
+
|
204 |
+
Returns:
|
205 |
+
None
|
206 |
+
|
207 |
+
"""
|
208 |
+
if not self.frame_exists(frame_id):
|
209 |
+
self.frames[frame_id] = Frame(frame_id, timestamp)
|
210 |
+
else:
|
211 |
+
raise ValueError("Frame id: {} already exists".format(frame_id))
|
212 |
+
|
213 |
+
def bbox_exists(self, frame_id: int, bbox_id: int) -> bool:
|
214 |
+
"""
|
215 |
+
Args:
|
216 |
+
frame_id:
|
217 |
+
bbox_id:
|
218 |
+
|
219 |
+
Returns:
|
220 |
+
bool: if bbox exists in frame bboxes list
|
221 |
+
"""
|
222 |
+
bboxes = []
|
223 |
+
if self.frame_exists(frame_id=frame_id):
|
224 |
+
bboxes = [bbox.bbox_id for bbox in self.frames[frame_id].bboxes]
|
225 |
+
return bbox_id in bboxes
|
226 |
+
|
227 |
+
def find_bbox(self, frame_id: int, bbox_id: int):
|
228 |
+
"""
|
229 |
+
|
230 |
+
Args:
|
231 |
+
frame_id:
|
232 |
+
bbox_id:
|
233 |
+
|
234 |
+
Returns:
|
235 |
+
bbox_id (int):
|
236 |
+
|
237 |
+
Raises:
|
238 |
+
ValueError: if bbox_id does not exist in the bbox list of specific frame.
|
239 |
+
"""
|
240 |
+
if not self.bbox_exists(frame_id, bbox_id):
|
241 |
+
raise ValueError("frame with id: {} does not contain bbox with id: {}".format(frame_id, bbox_id))
|
242 |
+
bboxes = {bbox.bbox_id: bbox for bbox in self.frames[frame_id].bboxes}
|
243 |
+
return bboxes.get(bbox_id)
|
244 |
+
|
245 |
+
def add_bbox_to_frame(self, frame_id: int, bbox_id: int, top: int, left: int, width: int, height: int) -> None:
|
246 |
+
"""
|
247 |
+
|
248 |
+
Args:
|
249 |
+
frame_id (int):
|
250 |
+
bbox_id (int):
|
251 |
+
top (int):
|
252 |
+
left (int):
|
253 |
+
width (int):
|
254 |
+
height (int):
|
255 |
+
|
256 |
+
Returns:
|
257 |
+
None
|
258 |
+
|
259 |
+
Raises:
|
260 |
+
ValueError: if bbox_id already exist in frame information with frame_id
|
261 |
+
ValueError: if frame_id does not exist in frames attribute
|
262 |
+
"""
|
263 |
+
if self.frame_exists(frame_id):
|
264 |
+
frame = self.frames[frame_id]
|
265 |
+
if not self.bbox_exists(frame_id, bbox_id):
|
266 |
+
frame.add_bbox(bbox_id, top, left, width, height)
|
267 |
+
else:
|
268 |
+
raise ValueError(
|
269 |
+
"frame with frame_id: {} already contains the bbox with id: {} ".format(frame_id, bbox_id))
|
270 |
+
else:
|
271 |
+
raise ValueError("frame with frame_id: {} does not exist".format(frame_id))
|
272 |
+
|
273 |
+
def add_label_to_bbox(self, frame_id: int, bbox_id: int, category: str, confidence: float):
|
274 |
+
"""
|
275 |
+
Args:
|
276 |
+
frame_id:
|
277 |
+
bbox_id:
|
278 |
+
category:
|
279 |
+
confidence: the confidence value returned from yolo detection
|
280 |
+
|
281 |
+
Returns:
|
282 |
+
None
|
283 |
+
|
284 |
+
Raises:
|
285 |
+
ValueError: if labels quota (top_k_labels) exceeds.
|
286 |
+
"""
|
287 |
+
bbox = self.find_bbox(frame_id, bbox_id)
|
288 |
+
if not bbox.labels_full(self.top_k_labels):
|
289 |
+
bbox.add_label(category, confidence)
|
290 |
+
else:
|
291 |
+
raise ValueError("labels in frame_id: {}, bbox_id: {} is fulled".format(frame_id, bbox_id))
|
292 |
+
|
293 |
+
def add_video_details(self, frame_width: int = None, frame_height: int = None, frame_rate: int = None,
|
294 |
+
video_name: str = None):
|
295 |
+
self.video_details['frame_width'] = frame_width
|
296 |
+
self.video_details['frame_height'] = frame_height
|
297 |
+
self.video_details['frame_rate'] = frame_rate
|
298 |
+
self.video_details['video_name'] = video_name
|
299 |
+
|
300 |
+
def output(self):
|
301 |
+
output = {'video_details': self.video_details}
|
302 |
+
result = list(self.frames.values())
|
303 |
+
output['frames'] = [item.dic() for item in result]
|
304 |
+
return output
|
305 |
+
|
306 |
+
def json_output(self, output_name):
|
307 |
+
"""
|
308 |
+
Args:
|
309 |
+
output_name:
|
310 |
+
|
311 |
+
Returns:
|
312 |
+
None
|
313 |
+
|
314 |
+
Notes:
|
315 |
+
It creates the json output with `output_name` name.
|
316 |
+
"""
|
317 |
+
if not output_name.endswith('.json'):
|
318 |
+
output_name += '.json'
|
319 |
+
with open(output_name, 'w') as file:
|
320 |
+
json.dump(self.output(), file)
|
321 |
+
file.close()
|
322 |
+
|
323 |
+
def set_start(self):
|
324 |
+
self.start_time = datetime.now()
|
325 |
+
|
326 |
+
def schedule_output_by_time(self, output_dir=JsonMeta.PATH_TO_SAVE, hours: int = 0, minutes: int = 0,
|
327 |
+
seconds: int = 60) -> None:
|
328 |
+
"""
|
329 |
+
Notes:
|
330 |
+
Creates folder and then periodically stores the jsons on that address.
|
331 |
+
|
332 |
+
Args:
|
333 |
+
output_dir (str): the directory where output files will be stored
|
334 |
+
hours (int):
|
335 |
+
minutes (int):
|
336 |
+
seconds (int):
|
337 |
+
|
338 |
+
Returns:
|
339 |
+
None
|
340 |
+
|
341 |
+
"""
|
342 |
+
end = datetime.now()
|
343 |
+
interval = 0
|
344 |
+
interval += abs(min([hours, JsonMeta.HOURS]) * 3600)
|
345 |
+
interval += abs(min([minutes, JsonMeta.MINUTES]) * 60)
|
346 |
+
interval += abs(min([seconds, JsonMeta.SECONDS]))
|
347 |
+
diff = (end - self.start_time).seconds
|
348 |
+
|
349 |
+
if diff > interval:
|
350 |
+
output_name = self.start_time.strftime('%Y-%m-%d %H-%M-%S') + '.json'
|
351 |
+
if not exists(output_dir):
|
352 |
+
makedirs(output_dir)
|
353 |
+
output = join(output_dir, output_name)
|
354 |
+
self.json_output(output_name=output)
|
355 |
+
self.frames = {}
|
356 |
+
self.start_time = datetime.now()
|
357 |
+
|
358 |
+
def schedule_output_by_frames(self, frames_quota, frame_counter, output_dir=JsonMeta.PATH_TO_SAVE):
|
359 |
+
"""
|
360 |
+
saves as the number of frames quota increases higher.
|
361 |
+
:param frames_quota:
|
362 |
+
:param frame_counter:
|
363 |
+
:param output_dir:
|
364 |
+
:return:
|
365 |
+
"""
|
366 |
+
pass
|
367 |
+
|
368 |
+
def flush(self, output_dir):
|
369 |
+
"""
|
370 |
+
Notes:
|
371 |
+
We use this function to output jsons whenever possible.
|
372 |
+
like the time that we exit the while loop of opencv.
|
373 |
+
|
374 |
+
Args:
|
375 |
+
output_dir:
|
376 |
+
|
377 |
+
Returns:
|
378 |
+
None
|
379 |
+
|
380 |
+
"""
|
381 |
+
filename = self.start_time.strftime('%Y-%m-%d %H-%M-%S') + '-remaining.json'
|
382 |
+
output = join(output_dir, filename)
|
383 |
+
self.json_output(output_name=output)
|
deep_sort/utils/log.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
|
4 |
+
def get_logger(name='root'):
|
5 |
+
formatter = logging.Formatter(
|
6 |
+
# fmt='%(asctime)s [%(levelname)s]: %(filename)s(%(funcName)s:%(lineno)s) >> %(message)s')
|
7 |
+
fmt='%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
8 |
+
|
9 |
+
handler = logging.StreamHandler()
|
10 |
+
handler.setFormatter(formatter)
|
11 |
+
|
12 |
+
logger = logging.getLogger(name)
|
13 |
+
logger.setLevel(logging.INFO)
|
14 |
+
logger.addHandler(handler)
|
15 |
+
return logger
|
16 |
+
|
17 |
+
|
deep_sort/utils/parser.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
from easydict import EasyDict as edict
|
4 |
+
|
5 |
+
class YamlParser(edict):
|
6 |
+
"""
|
7 |
+
This is yaml parser based on EasyDict.
|
8 |
+
"""
|
9 |
+
def __init__(self, cfg_dict=None, config_file=None):
|
10 |
+
if cfg_dict is None:
|
11 |
+
cfg_dict = {}
|
12 |
+
|
13 |
+
if config_file is not None:
|
14 |
+
assert(os.path.isfile(config_file))
|
15 |
+
with open(config_file, 'r') as fo:
|
16 |
+
cfg_dict.update(yaml.load(fo.read()))
|
17 |
+
|
18 |
+
super(YamlParser, self).__init__(cfg_dict)
|
19 |
+
|
20 |
+
|
21 |
+
def merge_from_file(self, config_file):
|
22 |
+
with open(config_file, 'r') as fo:
|
23 |
+
#self.update(yaml.load(fo.read()))
|
24 |
+
self.update(yaml.load(fo.read(),Loader=yaml.FullLoader))
|
25 |
+
|
26 |
+
def merge_from_dict(self, config_dict):
|
27 |
+
self.update(config_dict)
|
28 |
+
|
29 |
+
|
30 |
+
def get_config(config_file=None):
|
31 |
+
return YamlParser(config_file=config_file)
|
32 |
+
|
33 |
+
|
34 |
+
if __name__ == "__main__":
|
35 |
+
cfg = YamlParser(config_file="../configs/yolov3.yaml")
|
36 |
+
cfg.merge_from_file("../configs/deep_sort.yaml")
|
37 |
+
|
38 |
+
import ipdb; ipdb.set_trace()
|