Commit d5c53f9

fresh start without image history

Files changed:
- .gitattributes +35 -0
- .gitignore +5 -0
- README.md +14 -0
- app.py +40 -0
- configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- configs/sam2/sam2_hiera_b+.yaml +113 -0
- configs/sam2/sam2_hiera_l.yaml +117 -0
- configs/sam2/sam2_hiera_s.yaml +116 -0
- configs/sam2/sam2_hiera_t.yaml +118 -0
- requirements.txt +9 -0
- sam2 +1 -0
- sam2segment_structure.py +887 -0
- yolo11n.pt +3 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,5 @@
+*.jpg
+*.png
+driver_182_30frame/
+*.jpg
+*.png
README.md
ADDED
@@ -0,0 +1,14 @@
+---
+title: DBDLD
+emoji: 🦀
+colorFrom: indigo
+colorTo: purple
+sdk: gradio
+sdk_version: 5.25.2
+app_file: app.py
+pinned: false
+license: mit
+short_description: The backdoor trigger demo
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,40 @@
+import gradio as gr
+from sam2segment_structure import generate_trigger_crop
+import os
+
+# Dummy lane_data (later this can be read dynamically from JSON or from a user upload)
+dummy_lane_data = {
+    "lanes": [[-2, -2, -2, 814, 751, 688, 625, 562, 500, 438, 373, 305, 234, 160, 88, 16, -64, -2, -2, -2]],
+    "h_samples": [200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390],
+    "raw_file": "driver_182_30frame/06010513_0036.MP4/00270.jpg"
+}
+
+def process_trigger_with_path(input_image, save_path):
+    # Make sure the target directory exists
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+    # Save the uploaded image to the specified path
+    input_image.save(save_path)
+
+    # Point dummy_lane_data's raw_file at the current path
+    dummy_lane_data["raw_file"] = save_path
+
+    # Call the main processing function
+    crop_path, mask_path = generate_trigger_crop(save_path, dummy_lane_data)
+    return crop_path, mask_path
+
+demo = gr.Interface(
+    fn=process_trigger_with_path,
+    inputs=[
+        gr.Image(type="pil", label="Upload Image"),
+        gr.Textbox(label="Path to Save Image (e.g. driver_182_30frame/06010513_0036.MP4/00270.jpg)")
+    ],
+    outputs=[
+        gr.Image(type="filepath", label="Cropped Image"),
+        gr.Image(type="filepath", label="Cropped Mask")
+    ],
+    title="DBDLD Trigger Demo",
+    description="Upload an image and specify the target save path. The crop and mask will be generated accordingly."
+)
+
+demo.launch()
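Note on the format above: dummy_lane_data follows the TuSimple-style lane annotation convention, where lanes[i][j] gives the x pixel of lane i at row h_samples[j] and -2 marks rows with no lane point. A minimal sketch of decoding one entry into (x, y) points, assuming that convention (the helper name lane_to_points is hypothetical, not part of this commit):

# Hypothetical helper (not in this repo): decode one lane entry into
# (x, y) pixel points, skipping the -2 "no point" sentinel values.
def lane_to_points(lane, h_samples):
    return [(x, y) for x, y in zip(lane, h_samples) if x != -2]

points = lane_to_points(dummy_lane_data["lanes"][0], dummy_lane_data["h_samples"])
# -> [(814, 230), (751, 240), ..., (-64, 360)]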
configs/sam2.1/sam2.1_hiera_b+.yaml
ADDED
@@ -0,0 +1,116 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 112
+      num_heads: 2
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [896, 448, 224, 112]
+      fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  no_obj_embed_spatial: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: true
+  proj_tpos_enc_in_obj_ptrs: true
+  use_signed_tpos_enc_to_obj_ptrs: true
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  compile_image_encoder: False
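These YAML files are Hydra configs: every `_target_` names a class that Hydra instantiates, with the sibling keys passed as constructor arguments. As a rough sketch of how a config like the one above is consumed, assuming the official sam2 package layout (the checkpoint path is an assumption, not something this commit ships):

import torch
# Imports as in the official sam2 package; build_sam2 resolves the Hydra
# config (each _target_ becomes a constructed object) and loads the weights.
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = build_sam2("configs/sam2.1/sam2.1_hiera_b+.yaml",
                   "checkpoints/sam2.1_hiera_base_plus.pt",  # assumed path
                   device=device)
predictor = SAM2ImagePredictor(model)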
configs/sam2.1/sam2.1_hiera_l.yaml
ADDED
@@ -0,0 +1,120 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 144
+      num_heads: 2
+      stages: [2, 6, 36, 4]
+      global_att_blocks: [23, 33, 43]
+      window_pos_embed_bkg_spatial_size: [7, 7]
+      window_spec: [8, 4, 16, 8]
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [1152, 576, 288, 144]
+      fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  no_obj_embed_spatial: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: true
+  proj_tpos_enc_in_obj_ptrs: true
+  use_signed_tpos_enc_to_obj_ptrs: true
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  compile_image_encoder: False
configs/sam2.1/sam2.1_hiera_s.yaml
ADDED
@@ -0,0 +1,119 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 96
+      num_heads: 1
+      stages: [1, 2, 11, 2]
+      global_att_blocks: [7, 10, 13]
+      window_pos_embed_bkg_spatial_size: [7, 7]
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [768, 384, 192, 96]
+      fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  no_obj_embed_spatial: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: true
+  proj_tpos_enc_in_obj_ptrs: true
+  use_signed_tpos_enc_to_obj_ptrs: true
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  compile_image_encoder: False
configs/sam2.1/sam2.1_hiera_t.yaml
ADDED
@@ -0,0 +1,121 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 96
+      num_heads: 1
+      stages: [1, 2, 7, 2]
+      global_att_blocks: [5, 7, 9]
+      window_pos_embed_bkg_spatial_size: [7, 7]
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [768, 384, 192, 96]
+      fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  # SAM decoder
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  no_obj_embed_spatial: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: true
+  proj_tpos_enc_in_obj_ptrs: true
+  use_signed_tpos_enc_to_obj_ptrs: true
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  # HieraT does not currently support compilation, should always be set to False
+  compile_image_encoder: False
configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
ADDED
@@ -0,0 +1,339 @@
+# @package _global_
+
+scratch:
+  resolution: 1024
+  train_batch_size: 1
+  num_train_workers: 10
+  num_frames: 8
+  max_num_objects: 3
+  base_lr: 5.0e-6
+  vision_lr: 3.0e-06
+  phases_per_epoch: 1
+  num_epochs: 40
+
+dataset:
+  # PATHS to Dataset
+  img_folder: null # PATH to MOSE JPEGImages folder
+  gt_folder: null # PATH to MOSE Annotations folder
+  file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
+  multiplier: 2
+
+# Video transforms
+vos:
+  train_transforms:
+    - _target_: training.dataset.transforms.ComposeAPI
+      transforms:
+        - _target_: training.dataset.transforms.RandomHorizontalFlip
+          consistent_transform: True
+        - _target_: training.dataset.transforms.RandomAffine
+          degrees: 25
+          shear: 20
+          image_interpolation: bilinear
+          consistent_transform: True
+        - _target_: training.dataset.transforms.RandomResizeAPI
+          sizes: ${scratch.resolution}
+          square: true
+          consistent_transform: True
+        - _target_: training.dataset.transforms.ColorJitter
+          consistent_transform: True
+          brightness: 0.1
+          contrast: 0.03
+          saturation: 0.03
+          hue: null
+        - _target_: training.dataset.transforms.RandomGrayscale
+          p: 0.05
+          consistent_transform: True
+        - _target_: training.dataset.transforms.ColorJitter
+          consistent_transform: False
+          brightness: 0.1
+          contrast: 0.05
+          saturation: 0.05
+          hue: null
+        - _target_: training.dataset.transforms.ToTensorAPI
+        - _target_: training.dataset.transforms.NormalizeAPI
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+
+trainer:
+  _target_: training.trainer.Trainer
+  mode: train_only
+  max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
+  accelerator: cuda
+  seed_value: 123
+
+  model:
+    _target_: training.model.sam2.SAM2Train
+    image_encoder:
+      _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+      scalp: 1
+      trunk:
+        _target_: sam2.modeling.backbones.hieradet.Hiera
+        embed_dim: 112
+        num_heads: 2
+        drop_path_rate: 0.1
+      neck:
+        _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+        position_encoding:
+          _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+          num_pos_feats: 256
+          normalize: true
+          scale: null
+          temperature: 10000
+        d_model: 256
+        backbone_channel_list: [896, 448, 224, 112]
+        fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+        fpn_interp_model: nearest
+
+    memory_attention:
+      _target_: sam2.modeling.memory_attention.MemoryAttention
+      d_model: 256
+      pos_enc_at_input: true
+      layer:
+        _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+        activation: relu
+        dim_feedforward: 2048
+        dropout: 0.1
+        pos_enc_at_attn: false
+        self_attention:
+          _target_: sam2.modeling.sam.transformer.RoPEAttention
+          rope_theta: 10000.0
+          feat_sizes: [64, 64]
+          embedding_dim: 256
+          num_heads: 1
+          downsample_rate: 1
+          dropout: 0.1
+        d_model: 256
+        pos_enc_at_cross_attn_keys: true
+        pos_enc_at_cross_attn_queries: false
+        cross_attention:
+          _target_: sam2.modeling.sam.transformer.RoPEAttention
+          rope_theta: 10000.0
+          feat_sizes: [64, 64]
+          rope_k_repeat: True
+          embedding_dim: 256
+          num_heads: 1
+          downsample_rate: 1
+          dropout: 0.1
+          kv_in_dim: 64
+      num_layers: 4
+
+    memory_encoder:
+      _target_: sam2.modeling.memory_encoder.MemoryEncoder
+      out_dim: 64
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 64
+        normalize: true
+        scale: null
+        temperature: 10000
+      mask_downsampler:
+        _target_: sam2.modeling.memory_encoder.MaskDownSampler
+        kernel_size: 3
+        stride: 2
+        padding: 1
+      fuser:
+        _target_: sam2.modeling.memory_encoder.Fuser
+        layer:
+          _target_: sam2.modeling.memory_encoder.CXBlock
+          dim: 256
+          kernel_size: 7
+          padding: 3
+          layer_scale_init_value: 1e-6
+          use_dwconv: True # depth-wise convs
+        num_layers: 2
+
+    num_maskmem: 7
+    image_size: ${scratch.resolution}
+    # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+    sigmoid_scale_for_mem_enc: 20.0
+    sigmoid_bias_for_mem_enc: -10.0
+    use_mask_input_as_output_without_sam: true
+    # Memory
+    directly_add_no_mem_embed: true
+    no_obj_embed_spatial: true
+    # use high-resolution feature map in the SAM mask decoder
+    use_high_res_features_in_sam: true
+    # output 3 masks on the first click on initial conditioning frames
+    multimask_output_in_sam: true
+    # SAM heads
+    iou_prediction_use_sigmoid: True
+    # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+    use_obj_ptrs_in_encoder: true
+    add_tpos_enc_to_obj_ptrs: true
+    proj_tpos_enc_in_obj_ptrs: true
+    use_signed_tpos_enc_to_obj_ptrs: true
+    only_obj_ptrs_in_the_past_for_eval: true
+    # object occlusion prediction
+    pred_obj_scores: true
+    pred_obj_scores_mlp: true
+    fixed_no_obj_ptr: true
+    # multimask tracking settings
+    multimask_output_for_tracking: true
+    use_multimask_token_for_obj_ptr: true
+    multimask_min_pt_num: 0
+    multimask_max_pt_num: 1
+    use_mlp_for_obj_ptr_proj: true
+    # Compilation flag
+    # compile_image_encoder: False
+
+    ####### Training specific params #######
+    # box/point input and corrections
+    prob_to_use_pt_input_for_train: 0.5
+    prob_to_use_pt_input_for_eval: 0.0
+    prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
+    prob_to_use_box_input_for_eval: 0.0
+    prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
+    num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
+    num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
+    rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
+    add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
+    # maximum 2 initial conditioning frames
+    num_init_cond_frames_for_train: 2
+    rand_init_cond_frames_for_train: True # random 1~2
+    num_correction_pt_per_frame: 7
+    use_act_ckpt_iterative_pt_sampling: false
+
+
+
+    num_init_cond_frames_for_eval: 1 # only mask on the first frame
+    forward_backbone_per_frame_for_eval: True
+
+
+  data:
+    train:
+      _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
+      phases_per_epoch: ${scratch.phases_per_epoch}
+      batch_sizes:
+        - ${scratch.train_batch_size}
+
+      datasets:
+        - _target_: training.dataset.utils.RepeatFactorWrapper
+          dataset:
+            _target_: training.dataset.utils.ConcatDataset
+            datasets:
+              - _target_: training.dataset.vos_dataset.VOSDataset
+                transforms: ${vos.train_transforms}
+                training: true
+                video_dataset:
+                  _target_: training.dataset.vos_raw_dataset.PNGRawDataset
+                  img_folder: ${dataset.img_folder}
+                  gt_folder: ${dataset.gt_folder}
+                  file_list_txt: ${dataset.file_list_txt}
+                sampler:
+                  _target_: training.dataset.vos_sampler.RandomUniformSampler
+                  num_frames: ${scratch.num_frames}
+                  max_num_objects: ${scratch.max_num_objects}
+                multiplier: ${dataset.multiplier}
+      shuffle: True
+      num_workers: ${scratch.num_train_workers}
+      pin_memory: True
+      drop_last: True
+      collate_fn:
+        _target_: training.utils.data_utils.collate_fn
+        _partial_: true
+        dict_key: all
+
+  optim:
+    amp:
+      enabled: True
+      amp_dtype: bfloat16
+
+    optimizer:
+      _target_: torch.optim.AdamW
+
+    gradient_clip:
+      _target_: training.optimizer.GradientClipper
+      max_norm: 0.1
+      norm_type: 2
+
+    param_group_modifiers:
+      - _target_: training.optimizer.layer_decay_param_modifier
+        _partial_: True
+        layer_decay_value: 0.9
+        apply_to: 'image_encoder.trunk'
+        overrides:
+          - pattern: '*pos_embed*'
+            value: 1.0
+
+    options:
+      lr:
+        - scheduler:
+            _target_: fvcore.common.param_scheduler.CosineParamScheduler
+            start_value: ${scratch.base_lr}
+            end_value: ${divide:${scratch.base_lr},10}
+        - scheduler:
+            _target_: fvcore.common.param_scheduler.CosineParamScheduler
+            start_value: ${scratch.vision_lr}
+            end_value: ${divide:${scratch.vision_lr},10}
+          param_names:
+            - 'image_encoder.*'
+      weight_decay:
+        - scheduler:
+            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
+            value: 0.1
+        - scheduler:
+            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
+            value: 0.0
+          param_names:
+            - '*bias*'
+          module_cls_names: ['torch.nn.LayerNorm']
+
+  loss:
+    all:
+      _target_: training.loss_fns.MultiStepMultiMasksAndIous
+      weight_dict:
+        loss_mask: 20
+        loss_dice: 1
+        loss_iou: 1
+        loss_class: 1
+      supervise_all_iou: true
+      iou_use_l1_loss: true
+      pred_obj_scores: true
+      focal_gamma_obj_score: 0.0
+      focal_alpha_obj_score: -1.0
+
+  distributed:
+    backend: nccl
+    find_unused_parameters: True
+
+  logging:
+    tensorboard_writer:
+      _target_: training.utils.logger.make_tensorboard_logger
+      log_dir: ${launcher.experiment_log_dir}/tensorboard
+      flush_secs: 120
+      should_log: True
+    log_dir: ${launcher.experiment_log_dir}/logs
+    log_freq: 10
+
+  # initialize from a SAM 2 checkpoint
+  checkpoint:
+    save_dir: ${launcher.experiment_log_dir}/checkpoints
+    save_freq: 0 # 0 only last checkpoint is saved.
+    model_weight_initializer:
+      _partial_: True
+      _target_: training.utils.checkpoint_utils.load_state_dict_into_model
+      strict: True
+      ignore_unexpected_keys: null
+      ignore_missing_keys: null
+
+      state_dict:
+        _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
+        checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
+        ckpt_state_dict_keys: ['model']
+
+launcher:
+  num_nodes: 1
+  gpus_per_node: 8
+  experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
+
+# SLURM args if running on a cluster
+submitit:
+  partition: null
+  account: null
+  qos: null
+  cpus_per_task: 10
+  use_cluster: false
+  timeout_hour: 24
+  name: null
+  port_range: [10000, 65000]
+
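One detail worth flagging in the training config above: `${times:...}` and `${divide:...}` are not built-in OmegaConf interpolations; the sam2 training entry point registers them as custom resolvers before the YAML is resolved. A minimal sketch of that mechanism, assuming standard OmegaConf (the toy keys below are illustrative only):

from omegaconf import OmegaConf

# Resolvers with the names this config expects; the official train script
# registers equivalents, so this is only needed when loading the YAML yourself.
OmegaConf.register_new_resolver("times", lambda a, b: a * b)
OmegaConf.register_new_resolver("divide", lambda a, b: a / b)

cfg = OmegaConf.create(
    {"num_epochs": 40, "phases": 1, "max_epochs": "${times:${num_epochs},${phases}}"}
)
print(cfg.max_epochs)  # -> 40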
configs/sam2/sam2_hiera_b+.yaml
ADDED
@@ -0,0 +1,113 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 112
+      num_heads: 2
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [896, 448, 224, 112]
+      fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  compile_image_encoder: False
configs/sam2/sam2_hiera_l.yaml
ADDED
@@ -0,0 +1,117 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 144
+      num_heads: 2
+      stages: [2, 6, 36, 4]
+      global_att_blocks: [23, 33, 43]
+      window_pos_embed_bkg_spatial_size: [7, 7]
+      window_spec: [8, 4, 16, 8]
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [1152, 576, 288, 144]
+      fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  compile_image_encoder: False
configs/sam2/sam2_hiera_s.yaml
ADDED
@@ -0,0 +1,116 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 96
+      num_heads: 1
+      stages: [1, 2, 11, 2]
+      global_att_blocks: [7, 10, 13]
+      window_pos_embed_bkg_spatial_size: [7, 7]
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [768, 384, 192, 96]
+      fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  compile_image_encoder: False
configs/sam2/sam2_hiera_t.yaml
ADDED
@@ -0,0 +1,118 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 96
+      num_heads: 1
+      stages: [1, 2, 7, 2]
+      global_att_blocks: [5, 7, 9]
+      window_pos_embed_bkg_spatial_size: [7, 7]
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [768, 384, 192, 96]
+      fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [64, 64]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  # SAM decoder
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  # HieraT does not currently support compilation, should always be set to False
+  compile_image_encoder: False
requirements.txt
ADDED
@@ -0,0 +1,9 @@
+torch
+numpy
+opencv-python
+gradio
+matplotlib
+Pillow
+ultralytics
+diffusers
+huggingface_hub
sam2
ADDED
@@ -0,0 +1 @@
/data_sdf/yifan/sam2
sam2segment_structure.py
ADDED
@@ -0,0 +1,887 @@
import sys, os
import math
import pickle
import random
import warnings
import zipfile

import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
from diffusers.utils import load_image

# sys.path.append("/home/yifan/sam2")
# sys.path.append("/data_sdf/yifan/miniconda3/envs/sam2/lib/python3.10/site-packages")
sys.path.append(os.path.join(os.path.dirname(__file__), "sam2"))
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

warnings.filterwarnings("ignore", category=FutureWarning)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam2_checkpoint = hf_hub_download(
    repo_id="Evan73/sam2-models",
    filename="sam2.1_hiera_large.pt"
)
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)

# Download and unpack the precomputed attention heatmaps.
heatmap_zip = hf_hub_download(
    repo_id="Evan73/attention-heatmaps",
    filename="attention_heatmaps.zip"
)
with zipfile.ZipFile(heatmap_zip, 'r') as zip_ref:
    zip_ref.extractall("heatmaps_lda")

with open("heatmaps_lda/attention_heatmaps.pkl", "rb") as f:
    heatmap_dict = pickle.load(f)

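# NOTE (added for clarity): the pickle above is expected to map image paths to
# per-image heatmaps, inferred from the lookups in attention_mask() below --
# this sketch of the layout is an assumption, not a documented schema:
#   heatmap_dict["driver_182_30frame/06010513_0036.MP4/00270.jpg"] = {
#       "cls_heatmap": np.ndarray,  # classification-branch attention (LDA)
#       "reg_heatmap": np.ndarray,  # regression-branch attention (LOA / LRA)
#   }
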
def load_yolov5_model():
    # Load the detector with the Ultralytics model loader.
    # model = torch.hub.load('ultralytics/yolov11', 'yolov11s')  # other model sizes can be chosen as needed
    model = YOLO("yolo11n.pt")
    class_names = model.names  # class index to name mapping
    print("YOLOv11 Class Names:")
    for idx, name in class_names.items():
        print(f"{idx}: {name}")
    return model


# Check whether a point falls inside a vehicle region
def is_point_in_car_area(point, model, image):
    """
    Check whether the given point lies inside a detected vehicle box.
    - point: coordinates of the point (x, y)
    - model: YOLO model
    - image: input image
    Returns False if the point lies inside any vehicle box, True otherwise.
    """
    # Run YOLO object detection
    results = model(image)  # detection results

    # Vehicle classes (adjust the class IDs to the model in use)
    # print("Detected classes:", results[0].boxes.cls.cpu().numpy())
    car_class_id = [2, 5, 7]  # in COCO, car is usually 2 (bus 5, truck 7); verify for your model
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Iterate over the results (batching is supported; a single image is assumed here)
    for result in results:
        # Extract the xyxy box coordinates, confidences, and classes
        boxes = result.boxes.xyxy.cpu().numpy()  # top-left and bottom-right corners
        confidences = result.boxes.conf.cpu().numpy()
        class_ids = result.boxes.cls.cpu().numpy().astype(int)

        # Check each detection box
        for box, cls in zip(boxes, class_ids):
            if cls in car_class_id:
                x_min, y_min, x_max, y_max = box[:4]
                # Draw the detection box (optional)
                cv2.rectangle(image_bgr, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 2)
                # Check whether the point is inside the box
                if (x_min <= point[0] <= x_max) and (y_min <= point[1] <= y_max):
                    cv2.imwrite("yolo_res.jpg", image_bgr)
                    return False
    cv2.imwrite("yolo_res.jpg", image_bgr)
    print("Detection result saved to yolo_res.jpg")
    return True

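# Illustrative usage sketch (added for clarity; the point value is made up):
#   model = load_yolov5_model()
#   image = np.array(Image.open("driver_182_30frame/06010513_0036.MP4/00270.jpg").convert("RGB"))
#   ok = is_point_in_car_area((770, 325), model, image)
#   # ok is True when (770, 325) is clear of every detected car/bus/truck box.
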
def show_mask(mask, ax, image_path, random_color=False, borders=True, image=None, save_path=None):
    """
    Randomly pick two diagonal points inside the mask region and draw the
    rectangle between them on the original image.

    Parameters:
    - `mask`: mask region
    - `ax`: matplotlib axis used for drawing
    - `random_color`: whether to use a random color
    - `borders`: whether to draw mask borders
    - `image`: original image, used for drawing the rectangle
    - `save_path`: path for saving the result image
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])

    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    cv2.imwrite("binary_mask.png", (mask * 255).astype(np.uint8))
    print("Raw binary mask saved as binary_mask.png")
    contours = []
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=5)
    # print(f"Mask unique values: {np.unique(mask)}")
    # print(f"Max value in mask: {mask.max()}, Min value in mask: {mask.min()}")
    colors = [
        (255, 0, 0),      # red
        (0, 255, 0),      # green
        (0, 0, 255),      # blue
        (255, 255, 0),    # yellow
        (255, 0, 255),    # magenta
        (0, 255, 255),    # cyan
        (255, 128, 0),    # orange
        (128, 0, 255),    # purple
        (128, 128, 128),  # gray
        (0, 128, 0)       # dark green
    ]

    for idx, contour in enumerate(contours):
        x, y, w, h = cv2.boundingRect(contour)
        print(f"Contour {idx}: x={x}, y={y}, w={w}, h={h}")
        color = colors[idx % len(colors)]
        cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
    middle_save_path = "contours_colored_result.png"
    cv2.imwrite(middle_save_path, image)
    print(f"Colored contour result saved to {middle_save_path}")
    if image is not None:
        # Find the mask boundaries
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            if w > 50 and h > 50:
                for size in range(90, 40, -5):
                    # First try diagonals extending up-left from the sampled point
                    for _ in range(100):
                        random_x1 = random.randint(x, x + w - 50)
                        random_y1 = random.randint(y, y + h - 50)
                        random_x2 = random_x1 - size
                        random_y2 = random_y1 - size
                        # Draw the rectangle on the original image and save it
                        try:
                            if save_path and mask[random_y1, random_x1] == 1 and mask[random_y2, random_x2] == 1:
                                cv2.rectangle(image, (random_x2, random_y2), (random_x1, random_y1), (0, 255, 0), 2)
                                cv2.imwrite(save_path, image)
                                print(f"Image with rectangle saved at {save_path}")
                                return (random_x1, random_y1), (random_x2, random_y2)
                        except IndexError:
                            pass
                    # Then try diagonals extending down-right
                    for _ in range(100):
                        random_x1 = random.randint(x, x + w - 50)
                        random_y1 = random.randint(y, y + h - 50)
                        random_x2 = random_x1 + size
                        random_y2 = random_y1 + size
                        try:
                            if save_path and mask[random_y1, random_x1] == 1 and mask[random_y2, random_x2] == 1:
                                cv2.rectangle(image, (random_x2, random_y2), (random_x1, random_y1), (0, 255, 0), 2)
                                cv2.imwrite(save_path, image)
                                print(f"Image with rectangle saved at {save_path}")
                                return (random_x1, random_y1), (random_x2, random_y2)
                        except IndexError:
                            pass

    ax.imshow(mask_image)
    plt.axis('off')

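# NOTE (added for clarity): show_mask() above places the trigger rectangle at
# a *random* diagonal inside the segmented region; attention_mask() below is
# the attention-guided variant that scores candidate windows against the
# precomputed heatmaps instead of sampling blindly.
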
def attention_mask(mask, ax, image_path, strategy="LOA", random_color=False, borders=True, image=None, save_path=None):
    """
    Select a trigger box inside the mask region, guided by the precomputed
    attention heatmaps, and draw it on the original image.

    Parameters:
    - `mask`: mask region
    - `ax`: matplotlib axis used for drawing
    - `strategy`: "LDA" uses the classification heatmap; "LOA"/"LRA" use the regression heatmap
    - `random_color`: whether to use a random color
    - `borders`: whether to draw mask borders
    - `image`: original image, used for drawing the rectangle
    - `save_path`: path for saving the result image
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    orig_w, orig_h = image.shape[1], image.shape[0]
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    cv2.imwrite("binary_mask.png", (mask * 255).astype(np.uint8))
    print("Raw binary mask saved as binary_mask.png")

    candidates = []
    path = image_path
    cls_heatmap = heatmap_dict[path]['cls_heatmap']
    reg_heatmap = heatmap_dict[path]['reg_heatmap']
    font = cv2.FONT_HERSHEY_SIMPLEX
    if strategy == "LDA":
        combined = cls_heatmap.astype(np.float32)
    elif strategy in ("LOA", "LRA"):
        combined = reg_heatmap.astype(np.float32)
    else:
        raise ValueError(f"Unknown strategy: {strategy}")
    # Resize the mask to the heatmap resolution
    mask = cv2.resize(mask, (combined.shape[1], combined.shape[0]), interpolation=cv2.INTER_NEAREST)
    mask = (mask > 0.5).astype(np.uint8)
    cv2.imwrite("crop_binary_mask.png", (mask * 255).astype(np.uint8))
    print("Resized binary mask saved as crop_binary_mask.png")
    vis_image = cv2.imread(image_path)
    vis_image = cv2.resize(vis_image, (combined.shape[1], combined.shape[0]))
    mask_image = cv2.resize(mask_image, (combined.shape[1], combined.shape[0]))
    image = cv2.resize(image, (combined.shape[1], combined.shape[0]))
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
        colors = [
            (255, 0, 0),      # red
            (0, 255, 0),      # green
            (0, 0, 255),      # blue
            (255, 255, 0),    # yellow
            (255, 0, 255),    # magenta
            (0, 255, 255),    # cyan
            (255, 128, 0),    # orange
            (128, 0, 255),    # purple
            (128, 128, 128),  # gray
            (0, 128, 0)       # dark green
        ]
        for idx, contour in enumerate(contours):
            x, y, w, h = cv2.boundingRect(contour)
            print(f"Contour {idx}: x={x}, y={y}, w={w}, h={h}")
            color = colors[idx % len(colors)]
            cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
        middle_save_path = "contours_colored_result.png"
        cv2.imwrite(middle_save_path, image)
        print(f"Colored contour result saved to {middle_save_path}")
    if image is not None:
        # Find the mask boundaries
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            print("the contour is:", x, y, w, h)
            if w > 50 and h > 50:
                for size in range(50, 40, -5):
                    # Slide a size x size window over the bounding rect in 5 px steps
                    for y_step in range(y, y + h - size, 5):
                        for x_step in range(x, x + w - size, 5):
                            x1, y1, x2, y2 = x_step, y_step, x_step + size, y_step + size
                            if mask[y1:y2, x1:x2].sum() >= size * size:  # the window must lie entirely inside the mask
                                heat_value = combined[y1:y2, x1:x2].mean()
                                if not math.isnan(heat_value):
                                    candidates.append(((x1, y1, x2, y2), heat_value))
                                    cv2.rectangle(vis_image, (x1, y1), (x2, y2), (0, 255, 0), 1)
                                    cv2.putText(vis_image, f'{heat_value:.1f}', (x1, y1 - 2), font, 0.4, (0, 0, 255), 1)
                    if not candidates:
                        print("⚠️ No candidate window fully inside the mask was found")
                    else:
                        break
        cv2.imwrite("attention_vis.jpg", vis_image)
        print("Attention candidate boxes saved to attention_vis.jpg")
        # Sort by heat value from high to low and take the hottest valid candidate
        candidates.sort(key=lambda c: c[1], reverse=True)
        if candidates:
            print(save_path, candidates[0], candidates[-1])
        for (x1, y1, x2, y2), _ in candidates:
            try:
                if mask[y1, x1] == 1 and mask[y2, x2] == 1:
                    # Visualize and save
                    if save_path:
                        image = cv2.imread(image_path)
                        image = cv2.resize(image, (combined.shape[1], combined.shape[0]))
                        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
                        cv2.imwrite(save_path, image)
                        print(f"Image with rectangle saved at {save_path}")
                        # Map the box back to the original image resolution
                        resize_w, resize_h = combined.shape[1], combined.shape[0]
                        scale_x = orig_w / resize_w
                        scale_y = orig_h / resize_h
                        x1_orig = int(x1 * scale_x)
                        x2_orig = int(x2 * scale_x)
                        y1_orig = int(y1 * scale_y)
                        y2_orig = int(y2 * scale_y)
                        # Expand the box to a fixed size around its center
                        cx = (x1_orig + x2_orig) // 2
                        cy = (y1_orig + y2_orig) // 2
                        target_size = 90
                        half = target_size // 2
                        x1_exp = max(0, cx - half)
                        y1_exp = max(0, cy - half)
                        x2_exp = min(orig_w - 1, cx + half)
                        y2_exp = min(orig_h - 1, cy + half)
                        print(f"Expanded coordinates on the original image: ({x1_exp}, {y1_exp}), ({x2_exp}, {y2_exp})")
                        image_full = cv2.imread(image_path)  # read at original size
                        cv2.rectangle(image_full, (x1_exp, y1_exp), (x2_exp, y2_exp), (0, 0, 255), 2)
                        cv2.imwrite("expanded_bbox_on_original.jpg", image_full)
                        print("📌 Expanded candidate box drawn on the original image and saved as expanded_bbox_on_original.jpg")
                        return (x1_exp, y1_exp), (x2_exp, y2_exp)
            except Exception as e:
                print("the error is:", e)  # e.g. out-of-bounds indexing; try the next candidate

    ax.imshow(mask_image)
    plt.axis('off')

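# Worked example (added for clarity; numbers are illustrative): with size=50
# and stride 5, a 100x60 contour bounding rect yields 10 * 2 = 20 candidate
# windows; each window is kept only if all 2500 of its mask pixels are 1, and
# the survivors are ranked by combined[y1:y2, x1:x2].mean(), so the trigger
# lands on the detector's hottest fully-masked region.
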
def generate_gt_mask_from_intersection(random_rectangle, yolo_boxes, image, mask_img, sam2_model, threshold_iou):
    """
    Check whether the randomly generated rectangle is close enough to one of
    the YOLO boxes; if so, call SAM to obtain a precise mask to use as GT.
    """
    image_np = np.array(image)
    x1_rect, y1_rect = random_rectangle[0]
    x2_rect, y2_rect = random_rectangle[1]
    rect_mask = np.zeros(image_np.shape[:2], dtype=np.uint8)
    cv2.rectangle(rect_mask, (x1_rect, y1_rect), (x2_rect, y2_rect), color=255, thickness=-1)

    rect_box = [min(x1_rect, x2_rect), min(y1_rect, y2_rect), max(x1_rect, x2_rect), max(y1_rect, y2_rect)]

    for box in yolo_boxes:
        iou = calculate_iou(rect_box, box)
        print(f"IoU with YOLO box: {iou}, threshold: {threshold_iou}")

        if iou >= threshold_iou:
            # Sample three random points inside the YOLO box
            x_min, y_min, x_max, y_max = box
            input_point1 = (np.random.randint(x_min, x_max), np.random.randint(y_min, y_max))
            input_point2 = (np.random.randint(x_min, x_max), np.random.randint(y_min, y_max))
            input_point3 = (np.random.randint(x_min, x_max), np.random.randint(y_min, y_max))

            # Generate a precise mask with SAM
            gt_mask = get_gt_mask_from_sam(image, sam2_model, [input_point1, input_point2, input_point3], rect_mask)
            mask_img[gt_mask > 0] = 0
            # Save the GT mask
            cv2.imwrite('gt_mask_from_sam.png', gt_mask)
            print("SAM-generated GT mask saved to gt_mask_from_sam.png")

            return gt_mask, mask_img
    h, w = image_np.shape[:2]
    black_mask = np.zeros((h, w), dtype=np.uint8)
    no_match_save_path = 'gt_mask_from_sam.png'
    cv2.imwrite(no_match_save_path, black_mask)
    print("No YOLO box satisfied the IoU threshold.")
    print(f"No match found; empty mask saved to {no_match_save_path}")
    return None, mask_img

def calculate_iou(boxA, boxB):
    """Compute the IoU of two boxes."""
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    inter_area = max(0, xB - xA + 1) * max(0, yB - yA + 1)

    boxA_area = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    boxB_area = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)

    iou = inter_area / float(boxA_area + boxB_area - inter_area)
    return iou

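# Worked example (added for clarity; the boxes are made up): with the
# inclusive-pixel convention used above, boxA = [0, 0, 9, 9] and
# boxB = [5, 5, 14, 14] intersect over 5 * 5 = 25 px, each box covers
# 10 * 10 = 100 px, so IoU = 25 / (100 + 100 - 25) = 25 / 175 ≈ 0.143.
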
def get_gt_mask_from_sam(image, sam2_model, input_points, rect_mask):
    """Generate a mask with SAM from the given prompt points, and save both the picked points and the mask image."""
    predictor = SAM2ImagePredictor(sam2_model)
    print("load sam2")
    predictor.set_image(image)

    input_point_np = np.array(input_points)
    input_label = np.array([1] * len(input_points))  # all points are foreground prompts

    masks, _, _ = predictor.predict(
        point_coords=input_point_np,
        point_labels=input_label,
        multimask_output=False,
    )

    mask_img = masks[0].astype(np.uint8) * 255
    # mask_img[rect_mask == 255] = 0  # blank out the `random_rectangle` region

    # Save the SAM-generated mask
    mask_save_path = 'sam_gt_mask.jpg'
    cv2.imwrite(mask_save_path, mask_img)
    print(f"SAM-generated mask saved to {mask_save_path}")

    # Draw the picked points on the original image
    image_with_points = np.array(image).copy()
    for point in input_points:
        cv2.circle(image_with_points, point, radius=5, color=(255, 0, 0), thickness=-1)

    # Save the annotated image
    point_marked_save_path = 'image_with_points.jpg'
    image_bgr = cv2.cvtColor(image_with_points, cv2.COLOR_RGB2BGR)
    cv2.imwrite(point_marked_save_path, image_bgr)
    print(f"Image with marked points saved to {point_marked_save_path}")

    return mask_img

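# Minimal point-prompt sketch (added for clarity; the image path and point
# are illustrative): SAM2 takes N point prompts plus a parallel label array,
# where label 1 = foreground and 0 = background.
#   predictor.set_image(np.array(Image.open("crop.jpg").convert("RGB")))
#   masks, scores, _ = predictor.predict(
#       point_coords=np.array([[256, 256]]),
#       point_labels=np.array([1]),
#       multimask_output=False,
#   )
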
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

def display_mask(mask, ax, random_color=False, borders=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
        cv2.imwrite("check.jpg", mask_image)
    ax.imshow(mask_image)

def random_points_below(point, radius, min_distance, model, image, max_attempts=100):
    """
    Within the region 50 pixels below the given point, randomly pick two
    points until the constraints are satisfied.

    Parameters:
    - point: coordinates in (x, y) format
    - radius: maximum offset of the random points
    - min_distance: minimum distance between the two random points
    - max_attempts: attempt cap, to avoid an endless loop

    Returns:
    - the two random points, or None if no suitable pair was found
    """
    for _ in range(max_attempts):
        # Pick two random points in the region 50 px below the given point
        x1 = random.randint(point[0] - radius, point[0] + radius)
        y1 = random.randint(point[1] + 50, point[1] + 50 + radius)  # 50 px below

        x2 = random.randint(point[0] - radius, point[0] + radius)
        y2 = random.randint(point[1] + 50, point[1] + 50 + radius)  # 50 px below

        # Euclidean distance between the two points
        distance = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)

        # Check the distance constraint, and keep both points off any vehicle
        if distance >= min_distance and is_point_in_car_area((x1, y1), model, image) and is_point_in_car_area((x2, y2), model, image):
            return [(x1, y1), (x2, y2)]

    # Give up after max_attempts
    return None

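# Illustrative call (added for clarity; the values match the fallback used in
# get_left_right_points below, and the returned pair is hypothetical):
#   pts = random_points_below((770, 325), radius=50, min_distance=40, model=model, image=image)
#   # pts is e.g. [(748, 391), (803, 410)] -- two points >= 40 px apart, both
#   # at least 50 px below the seed point and clear of detected vehicles.
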
def show_masks(image, masks, scores, image_path, strategy, point_coords=None, box_coords=None, input_labels=None, borders=True, save_path=None):
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        display_mask(mask, plt.gca(), borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())
        if box_coords is not None:
            # boxes
            show_box(box_coords, plt.gca())
        plt.axis('off')
        plt.savefig('check.jpg', bbox_inches='tight', pad_inches=0)  # save the figure
        point1, point2 = attention_mask(mask, plt.gca(), image_path, strategy, borders=borders, image=image, save_path=save_path)
        return point1, point2  # only the first (best) mask is used

def random_crop(image, target_width, target_height, mask_point1, mask_point2):
    """Crop a target_width x target_height region centered on the midpoint of the two diagonal points, without running past the image borders."""
    width, height = image.size
    # Midpoint of the two diagonal points
    center_x = (mask_point1[0] + mask_point2[0]) // 2
    center_y = (mask_point1[1] + mask_point2[1]) // 2

    # Top-left and bottom-right corners of the crop region
    left = center_x - target_width // 2
    top = center_y - target_height // 2
    right = left + target_width
    bottom = top + target_height

    # Make sure the crop region stays inside the image
    if left < 0:
        left = 0
        right = target_width
    if top < 0:
        top = 0
        bottom = target_height
    if right > width:
        right = width
        left = width - target_width
    if bottom > height:
        bottom = height
        top = height - target_height

    # Record the crop offsets
    top_padding = max(0, top)
    left_padding = max(0, left)

    # Crop the image
    cropped_image = image.crop((left, top, right, bottom))

    global_mask_point1_relative = (mask_point1[0] - left, mask_point1[1] - top)
    global_mask_point2_relative = (mask_point2[0] - left, mask_point2[1] - top)
    print("Relative positions of the points after cropping:")
    print("mask_point1:", global_mask_point1_relative)
    print("mask_point2:", global_mask_point2_relative)
    return cropped_image, top_padding, left_padding, global_mask_point1_relative, global_mask_point2_relative

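# Worked example (added for clarity; the numbers are illustrative): on a
# 1640x590 image with mask points (100, 520) and (150, 570), the midpoint is
# (125, 545); a 512x512 crop would span left=-131 and bottom=801, so both
# clamps fire and the final window is (0, 78)-(512, 590), with the first mask
# point landing at (100, 442) in crop coordinates.
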
def get_left_right_points(lane_data, image_path):
    lanes = lane_data["lanes"]
    h_samples = lane_data["h_samples"]
    model = load_yolov5_model()
    # Middle index of h_samples
    mid_idx = len(h_samples) // 2
    image = cv2.imread(image_path)
    # Leftmost and rightmost points
    left_point = None
    right_point = None
    points = []
    # Walk over every lane line
    for lane in lanes:
        # Drop the invalid points marked with -2
        valid_points = [(x, y) for x, y in zip(lane, h_samples) if x != -2]

        if valid_points:
            if lane[mid_idx] != -2:
                # Scan from just above the middle sample toward the top of the image
                for i in range(mid_idx - 2, 0, -1):
                    left_point = lane[i]
                    print(left_point)
                    if lane[i] != -2:
                        point = (left_point, h_samples[i])
                        FLAG = is_point_in_car_area(point, model, image)
                        print(point, FLAG)
                        if FLAG:
                            points.append((left_point, h_samples[i]))
                            break
            else:
                point = (1540 // 2, 590 // 2 + 30)  # seed point (integer coordinates, as random.randint requires)
                radius = 50          # maximum offset of the random points
                min_distance = 40    # minimum distance between the two points
                points = random_points_below(point, radius, min_distance, model, image)

    # return left_point, right_point
    return points

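# Data-format note (added for clarity; the values are taken from the __main__
# example below): each lane is a list of x coordinates paired index-by-index
# with h_samples, and -2 marks rows where the lane is absent, e.g.
#   lanes[0]  = [-2, -2, ..., 814, 751, 688, ...]
#   h_samples = [200, 210, ..., 270, 280, 290, ...]
#   # -> lane 0 passes through (814, 270), (751, 280), (688, 290), ...
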
def sam2segment(image_path, points, strategy):
    image = Image.open(image_path)
    image = np.array(image.convert("RGB"))
    predictor.set_image(image)
    # First pass: a single point prompt with multimask output
    input_point = np.array([(points[0][0], points[0][1])])
    input_label = np.array([1])
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    sorted_ind = np.argsort(scores)[::-1]
    masks = masks[sorted_ind]
    scores = scores[sorted_ind]
    logits = logits[sorted_ind]
    # Second pass: refine with all points, seeded by the model's best mask
    mask_input = logits[np.argmax(scores), :, :]  # choose the model's best mask
    points_set = []
    for point in points:
        points_set.append((point[0], point[1]))
    input_point = np.array(points_set)
    input_label = np.array([1] * len(points_set))
    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        mask_input=mask_input[None, :, :],
        multimask_output=False,
    )
    point1, point2 = show_masks(image, masks, scores, image_path, strategy, save_path="masked_image.jpg")
    return point1, point2

def draw_point(image_path, points):
    image = cv2.imread(image_path)
    if image is not None:
        # Draw the points
        for point in points:
            cv2.circle(image, point, radius=5, color=(0, 255, 0), thickness=-1)  # green dots

        # Save the image
        output_path = "output_image_with_points.jpg"
        cv2.imwrite(output_path, image)
        print(f"Image saved with points at {output_path}")
    else:
        print("Error: Image could not be loaded.")

def generate_mask(original_img_path, point1, point2):
    """Generate a mask image from the two coordinates."""
    # Read the original image
    original_img = cv2.imread(original_img_path)

    # Get the original image size
    height, width, _ = original_img.shape

    # Create a black mask of the same size as the original image
    mask = np.zeros((height, width), dtype=np.uint8)

    # Compute the point 95% of the way from point1 to point2
    # (the variable name still says 3/4 from an earlier version)
    three_quarter_point = (
        int(point1[0] + 0.95 * (point2[0] - point1[0])),  # x coordinate
        int(point1[1] + 0.95 * (point2[1] - point1[1]))   # y coordinate
    )

    # Fill the rectangle between point1 and that point with white
    cv2.rectangle(mask, point1, three_quarter_point, color=255, thickness=-1)

    # Save the generated mask image
    mask_path = original_img_path.replace('test.jpg', 'mask_test.jpg')
    cv2.imwrite(mask_path, mask)
    print(mask_path)
    return mask_path, point1, three_quarter_point

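# Worked example (added for clarity; the points are illustrative): for
# point1 = (100, 100) and point2 = (200, 300), the shrunken corner is
# (100 + 0.95 * 100, 100 + 0.95 * 200) = (195, 290), so the white trigger
# rectangle spans (100, 100)-(195, 290) rather than the full diagonal.
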
def extract_lanes_in_crop(lane_data, crop_x_min, crop_x_max, crop_y_min, crop_y_max):
    """
    Filter the TuSimple-style `lanes`, keeping only the parts inside the crop.
    """
    cropped_lanes = []
    for lane in lane_data["lanes"]:
        cropped_lane = []
        for x, y in zip(lane, lane_data["h_samples"]):
            if x != -2 and crop_x_min <= x <= crop_x_max and crop_y_min <= y <= crop_y_max:
                cropped_lane.append((x, y))
                # new_x = x - crop_x_min
                # new_y = y - crop_y_min
                # cropped_lane.append((new_x, new_y))
        if cropped_lane:
            cropped_lanes.append(cropped_lane)

    return cropped_lanes


def generate_trigger_crop(image_path: str, lane_data: dict):
    """
    Given an image path, return the paths of the processed crop image and
    the crop mask image.
    """
    # 1. Pick the trigger points
    points = get_left_right_points(lane_data, image_path)
    print(f"[INFO] Trigger points: {points}")
    draw_point(image_path, points)

    # 2. Use SAM2 to obtain the mask points
    image = load_image(image_path)
    mask_point1, mask_point2 = sam2segment(image_path, points, "LDA")

    # 3. Crop the original image
    input_image, *_ = random_crop(image, 512, 512, mask_point1, mask_point2)
    input_crop_path = "crop.jpg"
    input_image.save(input_crop_path)

    # 4. Generate the trigger mask
    mask_path, point1, point2 = generate_mask(image_path, mask_point1, mask_point2)
    mask_img = load_image(mask_path)
    mask_img, *_ = random_crop(mask_img, 512, 512, mask_point1, mask_point2)
    crop_mask_path = "crop_mask.jpg"
    cv2.imwrite(crop_mask_path, np.array(mask_img))

    return input_crop_path, crop_mask_path

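# End-to-end usage sketch (added for clarity; the lane_data literal mirrors
# the __main__ example below):
#   crop_path, mask_path = generate_trigger_crop(
#       "driver_182_30frame/06010513_0036.MP4/00270.jpg", lane_data)
#   # crop_path -> "crop.jpg"      (512x512 crop around the trigger region)
#   # mask_path -> "crop_mask.jpg" (white rectangle marking the trigger)
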
if __name__ == "__main__":
    lane_data = {"lanes": [[-2, -2, -2, -2, -2, -2, -2, 814, 751, 688, 625, 562, 500, 438, 373, 305, 234, 160, 88, 16, -64, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2], [-2, -2, -2, -2, -2, -2, -2, 818, 801, 784, 768, 751, 734, 717, 701, 685, 668, 651, 634, 618, 601, 585, 568, 551, 535, 518, 502, 484, 468, 451, 435, 418, 401, 385, 368, 351, 335, 318, 301, 287], [-2, -2, -2, -2, -2, -2, -2, 863, 872, 881, 890, 899, 908, 918, 927, 936, 945, 954, 964, 972, 982, 991, 1000, 1009, 1018, 1027, 1036, 1046, 1055, 1064, 1073, 1082, 1091, 1100, 1109, 1119, 1128, 1137, 1146, 1154]], "h_samples": [200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520, 530, 540, 550, 560, 570, 580, 590], "raw_file": "driver_182_30frame/06010513_0036.MP4/00270.jpg"}

    image_path = "driver_182_30frame/06010513_0036.MP4/00270.jpg"
    points = get_left_right_points(lane_data, image_path)
    print(points)
    draw_point(image_path, points)
    # left_point, right_point = get_left_right_points(lane_data)
    # print(f"Left point: {left_point}, Right point: {right_point}")
    # sam2segment(image_path, left_point, right_point)
    image = load_image(image_path)
    mask_point1, mask_point2 = sam2segment(image_path, points, "LDA")
    input_image, top_padding, left_padding, global_mask_point1_relative, global_mask_point2_relative = random_crop(image, 512, 512, mask_point1, mask_point2)
    input_image.save("crop.jpg")  # PIL's save() writes the crop directly
    print("Image saved with points at crop.jpg")
    mask_path, point1, point2 = generate_mask('culane_test.jpg', mask_point1, mask_point2)
    mask_img = load_image(mask_path)
    mask_img, top_padding, left_padding, global_mask_point1_relative, global_mask_point2_relative = random_crop(mask_img, 512, 512, mask_point1, mask_point2)

    mask_img = np.array(mask_img)
    model = load_yolov5_model()
    yolo_results = model(input_image)
    yolo_boxes = []
    car_class_id = [2, 5, 7]  # car, bus, truck class IDs; adjust as needed

    for result in yolo_results:
        boxes = result.boxes.xyxy.cpu().numpy()
        class_ids = result.boxes.cls.cpu().numpy().astype(int)

        for box, cls in zip(boxes, class_ids):
            if cls in car_class_id:
                x_min, y_min, x_max, y_max = box[:4]
                yolo_boxes.append([int(x_min), int(y_min), int(x_max), int(y_max)])
    _, mask_img = generate_gt_mask_from_intersection([global_mask_point1_relative, global_mask_point2_relative], yolo_boxes, input_image, mask_img, sam2_model, threshold_iou=0.01)
    cv2.imwrite("crop_mask.jpg", mask_img)

    print("Mask saved to crop_mask.jpg")
    crop_x_min = min(mask_point1[0], mask_point2[0])
    crop_x_max = max(mask_point1[0], mask_point2[0])
    crop_y_min = min(mask_point1[1], mask_point2[1])
    crop_y_max = max(mask_point1[1], mask_point2[1])

    # Get the lanes that fall within the crop
    # (extract_lanes_in_crop is defined at module level above)
    cropped_lanes = extract_lanes_in_crop(lane_data, crop_x_min, crop_x_max, crop_y_min, crop_y_max)

    def draw_lane_mask_on_original(image, cropped_lanes):
        """
        Draw only the lanes in `cropped_lanes` on the original image.
        """
        height, width, _ = image.shape
        lane_mask = np.zeros((height, width), dtype=np.uint8)

        for lane in cropped_lanes:
            points = np.array(lane, dtype=np.int32)
            cv2.polylines(lane_mask, [points], isClosed=False, color=255, thickness=10)

        return lane_mask

    def random_crop_lane(image, target_width, target_height, mask_point1, mask_point2):
        """Crop a region of the given size centered on the midpoint of the two diagonal points, without leaving the image."""
        # Make sure `image` is a NumPy array
        if isinstance(image, Image.Image):
            image = np.array(image)

        height, width = image.shape[:2]  # size of the NumPy array

        # Midpoint of the two diagonal points
        center_x = (mask_point1[0] + mask_point2[0]) // 2
        center_y = (mask_point1[1] + mask_point2[1]) // 2

        # Top-left and bottom-right corners of the crop region
        left = max(0, center_x - target_width // 2)
        top = max(0, center_y - target_height // 2)
        right = min(width, left + target_width)
        bottom = min(height, top + target_height)

        # Padding needed if the crop region runs past the borders
        top_padding = max(0, target_height - (bottom - top))
        left_padding = max(0, target_width - (right - left))

        # Crop with NumPy slicing
        cropped_image = image[top:bottom, left:right]

        return cropped_image, top_padding, left_padding

    # Draw the lane mask on the original image
    raw_image = np.array(load_image(image_path).convert("RGB"))
    lane_mask = draw_lane_mask_on_original(raw_image, cropped_lanes)
    lane_mask_pil = Image.fromarray(lane_mask)
    crop_image, top_padding, left_padding, global_mask_point1_relative, global_mask_point2_relative = random_crop(lane_mask_pil, 512, 512, mask_point1, mask_point2)

    # Save the lane mask
    crop_image.save("lane_mask_crop.jpg")
    print("✅ Lane mask saved as lane_mask_crop.jpg")

    crop_img = cv2.imread("crop.jpg")  # original crop (BGR)
    mask_img = cv2.imread("crop_mask.jpg", cv2.IMREAD_GRAYSCALE)  # mask (grayscale)
    if crop_img.shape[:2] != mask_img.shape:
        print("⚠️ Resizing mask to match crop image size...")
        mask_img = cv2.resize(mask_img, (crop_img.shape[1], crop_img.shape[0]))
    white_overlay = np.ones_like(crop_img) * 255  # all-white image
    masked_result = np.where(mask_img[:, :, None] == 255, white_overlay, crop_img)  # replace only the white region

    # Save the overlaid image
    cv2.imwrite("crop_with_mask.jpg", masked_result)
    print("✅ Overlaid mask image saved to crop_with_mask.jpg")

yolo11n.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0ebbc80d4a7680d14987a577cd21342b65ecfd94632bd9a8da63ae6417644ee1
size 5613764