Evan73 committed on
Commit d5c53f9 · 0 parents

fresh start without image history

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
+ *.jpg
+ *.png
+ driver_182_30frame/
+ *.jpg
+ *.png
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: DBDLD
+ emoji: 🦀
+ colorFrom: indigo
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 5.25.2
+ app_file: app.py
+ pinned: false
+ license: mit
+ short_description: The backdoor trigger demo
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,40 @@
+ import gradio as gr
+ from sam2segment_structure import generate_trigger_crop
+ import os
+
+ # Dummy lane_data (later this could be read from a JSON file or a user upload)
+ dummy_lane_data = {
+     "lanes": [[-2, -2, -2, 814, 751, 688, 625, 562, 500, 438, 373, 305, 234, 160, 88, 16, -64, -2, -2, -2]],
+     "h_samples": [200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390],
+     "raw_file": "driver_182_30frame/06010513_0036.MP4/00270.jpg"
+ }
+
+ def process_trigger_with_path(input_image, save_path):
+     # Make sure the target directory exists
+     os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+     # Save the uploaded image to the requested path
+     input_image.save(save_path)
+
+     # Point raw_file in dummy_lane_data at the current path
+     dummy_lane_data["raw_file"] = save_path
+
+     # Call the main processing function
+     crop_path, mask_path = generate_trigger_crop(save_path, dummy_lane_data)
+     return crop_path, mask_path
+
+ demo = gr.Interface(
+     fn=process_trigger_with_path,
+     inputs=[
+         gr.Image(type="pil", label="Upload Image"),
+         gr.Textbox(label="Path to Save Image (e.g. driver_182_30frame/06010513_0036.MP4/00270.jpg)")
+     ],
+     outputs=[
+         gr.Image(type="filepath", label="Cropped Image"),
+         gr.Image(type="filepath", label="Cropped Mask")
+     ],
+     title="DBDLD Trigger Demo",
+     description="Upload an image and specify the target save path. The crop and mask will be generated accordingly."
+ )
+
+ demo.launch()
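The `lanes`/`h_samples` layout in `dummy_lane_data` appears to follow the TuSimple-style lane annotation convention: each lane is a list of x coordinates, one per row in `h_samples`, with -2 marking rows where the lane is absent. A minimal sketch of decoding such a record into (x, y) pixel points; the helper `lane_to_points` is illustrative and not part of this repository:

# Illustrative sketch (not part of the commit): decode a TuSimple-style lane record.
def lane_to_points(lane_xs, h_samples):
    """Pair each valid x with its row; -2 means the lane is absent at that row."""
    return [(x, y) for x, y in zip(lane_xs, h_samples) if x >= 0]

lane_data = {
    "lanes": [[-2, -2, 814, 751, 688]],
    "h_samples": [200, 210, 220, 230, 240],
}
points = [lane_to_points(lane, lane_data["h_samples"]) for lane in lane_data["lanes"]]
print(points)  # [[(814, 220), (751, 230), (688, 240)]]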
configs/sam2.1/sam2.1_hiera_b+.yaml ADDED
@@ -0,0 +1,116 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 112
12
+ num_heads: 2
13
+ neck:
14
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15
+ position_encoding:
16
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17
+ num_pos_feats: 256
18
+ normalize: true
19
+ scale: null
20
+ temperature: 10000
21
+ d_model: 256
22
+ backbone_channel_list: [896, 448, 224, 112]
23
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24
+ fpn_interp_model: nearest
25
+
26
+ memory_attention:
27
+ _target_: sam2.modeling.memory_attention.MemoryAttention
28
+ d_model: 256
29
+ pos_enc_at_input: true
30
+ layer:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32
+ activation: relu
33
+ dim_feedforward: 2048
34
+ dropout: 0.1
35
+ pos_enc_at_attn: false
36
+ self_attention:
37
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
38
+ rope_theta: 10000.0
39
+ feat_sizes: [64, 64]
40
+ embedding_dim: 256
41
+ num_heads: 1
42
+ downsample_rate: 1
43
+ dropout: 0.1
44
+ d_model: 256
45
+ pos_enc_at_cross_attn_keys: true
46
+ pos_enc_at_cross_attn_queries: false
47
+ cross_attention:
48
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
49
+ rope_theta: 10000.0
50
+ feat_sizes: [64, 64]
51
+ rope_k_repeat: True
52
+ embedding_dim: 256
53
+ num_heads: 1
54
+ downsample_rate: 1
55
+ dropout: 0.1
56
+ kv_in_dim: 64
57
+ num_layers: 4
58
+
59
+ memory_encoder:
60
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
61
+ out_dim: 64
62
+ position_encoding:
63
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64
+ num_pos_feats: 64
65
+ normalize: true
66
+ scale: null
67
+ temperature: 10000
68
+ mask_downsampler:
69
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
70
+ kernel_size: 3
71
+ stride: 2
72
+ padding: 1
73
+ fuser:
74
+ _target_: sam2.modeling.memory_encoder.Fuser
75
+ layer:
76
+ _target_: sam2.modeling.memory_encoder.CXBlock
77
+ dim: 256
78
+ kernel_size: 7
79
+ padding: 3
80
+ layer_scale_init_value: 1e-6
81
+ use_dwconv: True # depth-wise convs
82
+ num_layers: 2
83
+
84
+ num_maskmem: 7
85
+ image_size: 1024
86
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87
+ sigmoid_scale_for_mem_enc: 20.0
88
+ sigmoid_bias_for_mem_enc: -10.0
89
+ use_mask_input_as_output_without_sam: true
90
+ # Memory
91
+ directly_add_no_mem_embed: true
92
+ no_obj_embed_spatial: true
93
+ # use high-resolution feature map in the SAM mask decoder
94
+ use_high_res_features_in_sam: true
95
+ # output 3 masks on the first click on initial conditioning frames
96
+ multimask_output_in_sam: true
97
+ # SAM heads
98
+ iou_prediction_use_sigmoid: True
99
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
100
+ use_obj_ptrs_in_encoder: true
101
+ add_tpos_enc_to_obj_ptrs: true
102
+ proj_tpos_enc_in_obj_ptrs: true
103
+ use_signed_tpos_enc_to_obj_ptrs: true
104
+ only_obj_ptrs_in_the_past_for_eval: true
105
+ # object occlusion prediction
106
+ pred_obj_scores: true
107
+ pred_obj_scores_mlp: true
108
+ fixed_no_obj_ptr: true
109
+ # multimask tracking settings
110
+ multimask_output_for_tracking: true
111
+ use_multimask_token_for_obj_ptr: true
112
+ multimask_min_pt_num: 0
113
+ multimask_max_pt_num: 1
114
+ use_mlp_for_obj_ptr_proj: true
115
+ # Compilation flag
116
+ compile_image_encoder: False
configs/sam2.1/sam2.1_hiera_l.yaml ADDED
@@ -0,0 +1,120 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 144
12
+ num_heads: 2
13
+ stages: [2, 6, 36, 4]
14
+ global_att_blocks: [23, 33, 43]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ window_spec: [8, 4, 16, 8]
17
+ neck:
18
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19
+ position_encoding:
20
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21
+ num_pos_feats: 256
22
+ normalize: true
23
+ scale: null
24
+ temperature: 10000
25
+ d_model: 256
26
+ backbone_channel_list: [1152, 576, 288, 144]
27
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28
+ fpn_interp_model: nearest
29
+
30
+ memory_attention:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttention
32
+ d_model: 256
33
+ pos_enc_at_input: true
34
+ layer:
35
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36
+ activation: relu
37
+ dim_feedforward: 2048
38
+ dropout: 0.1
39
+ pos_enc_at_attn: false
40
+ self_attention:
41
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
42
+ rope_theta: 10000.0
43
+ feat_sizes: [64, 64]
44
+ embedding_dim: 256
45
+ num_heads: 1
46
+ downsample_rate: 1
47
+ dropout: 0.1
48
+ d_model: 256
49
+ pos_enc_at_cross_attn_keys: true
50
+ pos_enc_at_cross_attn_queries: false
51
+ cross_attention:
52
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
53
+ rope_theta: 10000.0
54
+ feat_sizes: [64, 64]
55
+ rope_k_repeat: True
56
+ embedding_dim: 256
57
+ num_heads: 1
58
+ downsample_rate: 1
59
+ dropout: 0.1
60
+ kv_in_dim: 64
61
+ num_layers: 4
62
+
63
+ memory_encoder:
64
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
65
+ out_dim: 64
66
+ position_encoding:
67
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68
+ num_pos_feats: 64
69
+ normalize: true
70
+ scale: null
71
+ temperature: 10000
72
+ mask_downsampler:
73
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
74
+ kernel_size: 3
75
+ stride: 2
76
+ padding: 1
77
+ fuser:
78
+ _target_: sam2.modeling.memory_encoder.Fuser
79
+ layer:
80
+ _target_: sam2.modeling.memory_encoder.CXBlock
81
+ dim: 256
82
+ kernel_size: 7
83
+ padding: 3
84
+ layer_scale_init_value: 1e-6
85
+ use_dwconv: True # depth-wise convs
86
+ num_layers: 2
87
+
88
+ num_maskmem: 7
89
+ image_size: 1024
90
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ no_obj_embed_spatial: true
97
+ # use high-resolution feature map in the SAM mask decoder
98
+ use_high_res_features_in_sam: true
99
+ # output 3 masks on the first click on initial conditioning frames
100
+ multimask_output_in_sam: true
101
+ # SAM heads
102
+ iou_prediction_use_sigmoid: True
103
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
104
+ use_obj_ptrs_in_encoder: true
105
+ add_tpos_enc_to_obj_ptrs: true
106
+ proj_tpos_enc_in_obj_ptrs: true
107
+ use_signed_tpos_enc_to_obj_ptrs: true
108
+ only_obj_ptrs_in_the_past_for_eval: true
109
+ # object occlusion prediction
110
+ pred_obj_scores: true
111
+ pred_obj_scores_mlp: true
112
+ fixed_no_obj_ptr: true
113
+ # multimask tracking settings
114
+ multimask_output_for_tracking: true
115
+ use_multimask_token_for_obj_ptr: true
116
+ multimask_min_pt_num: 0
117
+ multimask_max_pt_num: 1
118
+ use_mlp_for_obj_ptr_proj: true
119
+ # Compilation flag
120
+ compile_image_encoder: False
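These Hydra-style YAML configs are consumed by `build_sam2`, which instantiates every `_target_` class. In this Space, `sam2segment_structure.py` pairs `configs/sam2.1/sam2.1_hiera_l.yaml` with a checkpoint downloaded from the Hub; a condensed sketch of that wiring (it mirrors the code added later in this commit and assumes a CUDA device):

import torch
from huggingface_hub import hf_hub_download
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Download the SAM 2.1 Hiera-L weights and build the model from this config.
ckpt = hf_hub_download(repo_id="Evan73/sam2-models", filename="sam2.1_hiera_large.pt")
sam2_model = build_sam2("configs/sam2.1/sam2.1_hiera_l.yaml", ckpt, device=torch.device("cuda"))
predictor = SAM2ImagePredictor(sam2_model)  # prompt-driven single-image mask prediction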
configs/sam2.1/sam2.1_hiera_s.yaml ADDED
@@ -0,0 +1,119 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 11, 2]
14
+ global_att_blocks: [7, 10, 13]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [64, 64]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [64, 64]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ sigmoid_scale_for_mem_enc: 20.0
91
+ sigmoid_bias_for_mem_enc: -10.0
92
+ use_mask_input_as_output_without_sam: true
93
+ # Memory
94
+ directly_add_no_mem_embed: true
95
+ no_obj_embed_spatial: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: true
105
+ proj_tpos_enc_in_obj_ptrs: true
106
+ use_signed_tpos_enc_to_obj_ptrs: true
107
+ only_obj_ptrs_in_the_past_for_eval: true
108
+ # object occlusion prediction
109
+ pred_obj_scores: true
110
+ pred_obj_scores_mlp: true
111
+ fixed_no_obj_ptr: true
112
+ # multimask tracking settings
113
+ multimask_output_for_tracking: true
114
+ use_multimask_token_for_obj_ptr: true
115
+ multimask_min_pt_num: 0
116
+ multimask_max_pt_num: 1
117
+ use_mlp_for_obj_ptr_proj: true
118
+ # Compilation flag
119
+ compile_image_encoder: False
configs/sam2.1/sam2.1_hiera_t.yaml ADDED
@@ -0,0 +1,121 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 7, 2]
14
+ global_att_blocks: [5, 7, 9]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [64, 64]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [64, 64]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ # SAM decoder
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ no_obj_embed_spatial: true
97
+ # use high-resolution feature map in the SAM mask decoder
98
+ use_high_res_features_in_sam: true
99
+ # output 3 masks on the first click on initial conditioning frames
100
+ multimask_output_in_sam: true
101
+ # SAM heads
102
+ iou_prediction_use_sigmoid: True
103
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
104
+ use_obj_ptrs_in_encoder: true
105
+ add_tpos_enc_to_obj_ptrs: true
106
+ proj_tpos_enc_in_obj_ptrs: true
107
+ use_signed_tpos_enc_to_obj_ptrs: true
108
+ only_obj_ptrs_in_the_past_for_eval: true
109
+ # object occlusion prediction
110
+ pred_obj_scores: true
111
+ pred_obj_scores_mlp: true
112
+ fixed_no_obj_ptr: true
113
+ # multimask tracking settings
114
+ multimask_output_for_tracking: true
115
+ use_multimask_token_for_obj_ptr: true
116
+ multimask_min_pt_num: 0
117
+ multimask_max_pt_num: 1
118
+ use_mlp_for_obj_ptr_proj: true
119
+ # Compilation flag
120
+ # HieraT does not currently support compilation, should always be set to False
121
+ compile_image_encoder: False
configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml ADDED
@@ -0,0 +1,339 @@
1
+ # @package _global_
2
+
3
+ scratch:
4
+ resolution: 1024
5
+ train_batch_size: 1
6
+ num_train_workers: 10
7
+ num_frames: 8
8
+ max_num_objects: 3
9
+ base_lr: 5.0e-6
10
+ vision_lr: 3.0e-06
11
+ phases_per_epoch: 1
12
+ num_epochs: 40
13
+
14
+ dataset:
15
+ # PATHS to Dataset
16
+ img_folder: null # PATH to MOSE JPEGImages folder
17
+ gt_folder: null # PATH to MOSE Annotations folder
18
+ file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
19
+ multiplier: 2
20
+
21
+ # Video transforms
22
+ vos:
23
+ train_transforms:
24
+ - _target_: training.dataset.transforms.ComposeAPI
25
+ transforms:
26
+ - _target_: training.dataset.transforms.RandomHorizontalFlip
27
+ consistent_transform: True
28
+ - _target_: training.dataset.transforms.RandomAffine
29
+ degrees: 25
30
+ shear: 20
31
+ image_interpolation: bilinear
32
+ consistent_transform: True
33
+ - _target_: training.dataset.transforms.RandomResizeAPI
34
+ sizes: ${scratch.resolution}
35
+ square: true
36
+ consistent_transform: True
37
+ - _target_: training.dataset.transforms.ColorJitter
38
+ consistent_transform: True
39
+ brightness: 0.1
40
+ contrast: 0.03
41
+ saturation: 0.03
42
+ hue: null
43
+ - _target_: training.dataset.transforms.RandomGrayscale
44
+ p: 0.05
45
+ consistent_transform: True
46
+ - _target_: training.dataset.transforms.ColorJitter
47
+ consistent_transform: False
48
+ brightness: 0.1
49
+ contrast: 0.05
50
+ saturation: 0.05
51
+ hue: null
52
+ - _target_: training.dataset.transforms.ToTensorAPI
53
+ - _target_: training.dataset.transforms.NormalizeAPI
54
+ mean: [0.485, 0.456, 0.406]
55
+ std: [0.229, 0.224, 0.225]
56
+
57
+ trainer:
58
+ _target_: training.trainer.Trainer
59
+ mode: train_only
60
+ max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
61
+ accelerator: cuda
62
+ seed_value: 123
63
+
64
+ model:
65
+ _target_: training.model.sam2.SAM2Train
66
+ image_encoder:
67
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
68
+ scalp: 1
69
+ trunk:
70
+ _target_: sam2.modeling.backbones.hieradet.Hiera
71
+ embed_dim: 112
72
+ num_heads: 2
73
+ drop_path_rate: 0.1
74
+ neck:
75
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
76
+ position_encoding:
77
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
78
+ num_pos_feats: 256
79
+ normalize: true
80
+ scale: null
81
+ temperature: 10000
82
+ d_model: 256
83
+ backbone_channel_list: [896, 448, 224, 112]
84
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
85
+ fpn_interp_model: nearest
86
+
87
+ memory_attention:
88
+ _target_: sam2.modeling.memory_attention.MemoryAttention
89
+ d_model: 256
90
+ pos_enc_at_input: true
91
+ layer:
92
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
93
+ activation: relu
94
+ dim_feedforward: 2048
95
+ dropout: 0.1
96
+ pos_enc_at_attn: false
97
+ self_attention:
98
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
99
+ rope_theta: 10000.0
100
+ feat_sizes: [64, 64]
101
+ embedding_dim: 256
102
+ num_heads: 1
103
+ downsample_rate: 1
104
+ dropout: 0.1
105
+ d_model: 256
106
+ pos_enc_at_cross_attn_keys: true
107
+ pos_enc_at_cross_attn_queries: false
108
+ cross_attention:
109
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
110
+ rope_theta: 10000.0
111
+ feat_sizes: [64, 64]
112
+ rope_k_repeat: True
113
+ embedding_dim: 256
114
+ num_heads: 1
115
+ downsample_rate: 1
116
+ dropout: 0.1
117
+ kv_in_dim: 64
118
+ num_layers: 4
119
+
120
+ memory_encoder:
121
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
122
+ out_dim: 64
123
+ position_encoding:
124
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
125
+ num_pos_feats: 64
126
+ normalize: true
127
+ scale: null
128
+ temperature: 10000
129
+ mask_downsampler:
130
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
131
+ kernel_size: 3
132
+ stride: 2
133
+ padding: 1
134
+ fuser:
135
+ _target_: sam2.modeling.memory_encoder.Fuser
136
+ layer:
137
+ _target_: sam2.modeling.memory_encoder.CXBlock
138
+ dim: 256
139
+ kernel_size: 7
140
+ padding: 3
141
+ layer_scale_init_value: 1e-6
142
+ use_dwconv: True # depth-wise convs
143
+ num_layers: 2
144
+
145
+ num_maskmem: 7
146
+ image_size: ${scratch.resolution}
147
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
148
+ sigmoid_scale_for_mem_enc: 20.0
149
+ sigmoid_bias_for_mem_enc: -10.0
150
+ use_mask_input_as_output_without_sam: true
151
+ # Memory
152
+ directly_add_no_mem_embed: true
153
+ no_obj_embed_spatial: true
154
+ # use high-resolution feature map in the SAM mask decoder
155
+ use_high_res_features_in_sam: true
156
+ # output 3 masks on the first click on initial conditioning frames
157
+ multimask_output_in_sam: true
158
+ # SAM heads
159
+ iou_prediction_use_sigmoid: True
160
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
161
+ use_obj_ptrs_in_encoder: true
162
+ add_tpos_enc_to_obj_ptrs: true
163
+ proj_tpos_enc_in_obj_ptrs: true
164
+ use_signed_tpos_enc_to_obj_ptrs: true
165
+ only_obj_ptrs_in_the_past_for_eval: true
166
+ # object occlusion prediction
167
+ pred_obj_scores: true
168
+ pred_obj_scores_mlp: true
169
+ fixed_no_obj_ptr: true
170
+ # multimask tracking settings
171
+ multimask_output_for_tracking: true
172
+ use_multimask_token_for_obj_ptr: true
173
+ multimask_min_pt_num: 0
174
+ multimask_max_pt_num: 1
175
+ use_mlp_for_obj_ptr_proj: true
176
+ # Compilation flag
177
+ # compile_image_encoder: False
178
+
179
+ ####### Training specific params #######
180
+ # box/point input and corrections
181
+ prob_to_use_pt_input_for_train: 0.5
182
+ prob_to_use_pt_input_for_eval: 0.0
183
+ prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
184
+ prob_to_use_box_input_for_eval: 0.0
185
+ prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
186
+ num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
187
+ num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
188
+ rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
189
+ add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
190
+ # maximum 2 initial conditioning frames
191
+ num_init_cond_frames_for_train: 2
192
+ rand_init_cond_frames_for_train: True # random 1~2
193
+ num_correction_pt_per_frame: 7
194
+ use_act_ckpt_iterative_pt_sampling: false
195
+
196
+
197
+
198
+ num_init_cond_frames_for_eval: 1 # only mask on the first frame
199
+ forward_backbone_per_frame_for_eval: True
200
+
201
+
202
+ data:
203
+ train:
204
+ _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
205
+ phases_per_epoch: ${scratch.phases_per_epoch}
206
+ batch_sizes:
207
+ - ${scratch.train_batch_size}
208
+
209
+ datasets:
210
+ - _target_: training.dataset.utils.RepeatFactorWrapper
211
+ dataset:
212
+ _target_: training.dataset.utils.ConcatDataset
213
+ datasets:
214
+ - _target_: training.dataset.vos_dataset.VOSDataset
215
+ transforms: ${vos.train_transforms}
216
+ training: true
217
+ video_dataset:
218
+ _target_: training.dataset.vos_raw_dataset.PNGRawDataset
219
+ img_folder: ${dataset.img_folder}
220
+ gt_folder: ${dataset.gt_folder}
221
+ file_list_txt: ${dataset.file_list_txt}
222
+ sampler:
223
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
224
+ num_frames: ${scratch.num_frames}
225
+ max_num_objects: ${scratch.max_num_objects}
226
+ multiplier: ${dataset.multiplier}
227
+ shuffle: True
228
+ num_workers: ${scratch.num_train_workers}
229
+ pin_memory: True
230
+ drop_last: True
231
+ collate_fn:
232
+ _target_: training.utils.data_utils.collate_fn
233
+ _partial_: true
234
+ dict_key: all
235
+
236
+ optim:
237
+ amp:
238
+ enabled: True
239
+ amp_dtype: bfloat16
240
+
241
+ optimizer:
242
+ _target_: torch.optim.AdamW
243
+
244
+ gradient_clip:
245
+ _target_: training.optimizer.GradientClipper
246
+ max_norm: 0.1
247
+ norm_type: 2
248
+
249
+ param_group_modifiers:
250
+ - _target_: training.optimizer.layer_decay_param_modifier
251
+ _partial_: True
252
+ layer_decay_value: 0.9
253
+ apply_to: 'image_encoder.trunk'
254
+ overrides:
255
+ - pattern: '*pos_embed*'
256
+ value: 1.0
257
+
258
+ options:
259
+ lr:
260
+ - scheduler:
261
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
262
+ start_value: ${scratch.base_lr}
263
+ end_value: ${divide:${scratch.base_lr},10}
264
+ - scheduler:
265
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
266
+ start_value: ${scratch.vision_lr}
267
+ end_value: ${divide:${scratch.vision_lr},10}
268
+ param_names:
269
+ - 'image_encoder.*'
270
+ weight_decay:
271
+ - scheduler:
272
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
273
+ value: 0.1
274
+ - scheduler:
275
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
276
+ value: 0.0
277
+ param_names:
278
+ - '*bias*'
279
+ module_cls_names: ['torch.nn.LayerNorm']
280
+
281
+ loss:
282
+ all:
283
+ _target_: training.loss_fns.MultiStepMultiMasksAndIous
284
+ weight_dict:
285
+ loss_mask: 20
286
+ loss_dice: 1
287
+ loss_iou: 1
288
+ loss_class: 1
289
+ supervise_all_iou: true
290
+ iou_use_l1_loss: true
291
+ pred_obj_scores: true
292
+ focal_gamma_obj_score: 0.0
293
+ focal_alpha_obj_score: -1.0
294
+
295
+ distributed:
296
+ backend: nccl
297
+ find_unused_parameters: True
298
+
299
+ logging:
300
+ tensorboard_writer:
301
+ _target_: training.utils.logger.make_tensorboard_logger
302
+ log_dir: ${launcher.experiment_log_dir}/tensorboard
303
+ flush_secs: 120
304
+ should_log: True
305
+ log_dir: ${launcher.experiment_log_dir}/logs
306
+ log_freq: 10
307
+
308
+ # initialize from a SAM 2 checkpoint
309
+ checkpoint:
310
+ save_dir: ${launcher.experiment_log_dir}/checkpoints
311
+ save_freq: 0 # 0 only last checkpoint is saved.
312
+ model_weight_initializer:
313
+ _partial_: True
314
+ _target_: training.utils.checkpoint_utils.load_state_dict_into_model
315
+ strict: True
316
+ ignore_unexpected_keys: null
317
+ ignore_missing_keys: null
318
+
319
+ state_dict:
320
+ _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
321
+ checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
322
+ ckpt_state_dict_keys: ['model']
323
+
324
+ launcher:
325
+ num_nodes: 1
326
+ gpus_per_node: 8
327
+ experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
328
+
329
+ # SLURM args if running on a cluster
330
+ submitit:
331
+ partition: null
332
+ account: null
333
+ qos: null
334
+ cpus_per_task: 10
335
+ use_cluster: false
336
+ timeout_hour: 24
337
+ name: null
338
+ port_range: [10000, 65000]
339
+
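Unlike the inference configs above, this training config relies on interpolation: `${scratch.*}` references the block at the top of the file, and `${times:...}` / `${divide:...}` are custom OmegaConf resolvers that the SAM 2 training entrypoint is expected to register before composing the config. A rough sketch of loading it standalone (the resolver names come from this file; the loading code itself is illustrative only):

from omegaconf import OmegaConf

# Register the arithmetic resolvers this config uses (normally done by the training entrypoint).
OmegaConf.register_new_resolver("times", lambda a, b: a * b, replace=True)
OmegaConf.register_new_resolver("divide", lambda a, b: a / b, replace=True)

cfg = OmegaConf.load("configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml")
print(cfg.trainer.max_epochs)                       # ${times:40,1} -> 40
print(cfg.optim.options.lr[0].scheduler.end_value)  # ${divide:5.0e-6,10} -> 5e-07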
configs/sam2/sam2_hiera_b+.yaml ADDED
@@ -0,0 +1,113 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 112
12
+ num_heads: 2
13
+ neck:
14
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15
+ position_encoding:
16
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17
+ num_pos_feats: 256
18
+ normalize: true
19
+ scale: null
20
+ temperature: 10000
21
+ d_model: 256
22
+ backbone_channel_list: [896, 448, 224, 112]
23
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24
+ fpn_interp_model: nearest
25
+
26
+ memory_attention:
27
+ _target_: sam2.modeling.memory_attention.MemoryAttention
28
+ d_model: 256
29
+ pos_enc_at_input: true
30
+ layer:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32
+ activation: relu
33
+ dim_feedforward: 2048
34
+ dropout: 0.1
35
+ pos_enc_at_attn: false
36
+ self_attention:
37
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
38
+ rope_theta: 10000.0
39
+ feat_sizes: [64, 64]
40
+ embedding_dim: 256
41
+ num_heads: 1
42
+ downsample_rate: 1
43
+ dropout: 0.1
44
+ d_model: 256
45
+ pos_enc_at_cross_attn_keys: true
46
+ pos_enc_at_cross_attn_queries: false
47
+ cross_attention:
48
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
49
+ rope_theta: 10000.0
50
+ feat_sizes: [64, 64]
51
+ rope_k_repeat: True
52
+ embedding_dim: 256
53
+ num_heads: 1
54
+ downsample_rate: 1
55
+ dropout: 0.1
56
+ kv_in_dim: 64
57
+ num_layers: 4
58
+
59
+ memory_encoder:
60
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
61
+ out_dim: 64
62
+ position_encoding:
63
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64
+ num_pos_feats: 64
65
+ normalize: true
66
+ scale: null
67
+ temperature: 10000
68
+ mask_downsampler:
69
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
70
+ kernel_size: 3
71
+ stride: 2
72
+ padding: 1
73
+ fuser:
74
+ _target_: sam2.modeling.memory_encoder.Fuser
75
+ layer:
76
+ _target_: sam2.modeling.memory_encoder.CXBlock
77
+ dim: 256
78
+ kernel_size: 7
79
+ padding: 3
80
+ layer_scale_init_value: 1e-6
81
+ use_dwconv: True # depth-wise convs
82
+ num_layers: 2
83
+
84
+ num_maskmem: 7
85
+ image_size: 1024
86
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87
+ sigmoid_scale_for_mem_enc: 20.0
88
+ sigmoid_bias_for_mem_enc: -10.0
89
+ use_mask_input_as_output_without_sam: true
90
+ # Memory
91
+ directly_add_no_mem_embed: true
92
+ # use high-resolution feature map in the SAM mask decoder
93
+ use_high_res_features_in_sam: true
94
+ # output 3 masks on the first click on initial conditioning frames
95
+ multimask_output_in_sam: true
96
+ # SAM heads
97
+ iou_prediction_use_sigmoid: True
98
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
99
+ use_obj_ptrs_in_encoder: true
100
+ add_tpos_enc_to_obj_ptrs: false
101
+ only_obj_ptrs_in_the_past_for_eval: true
102
+ # object occlusion prediction
103
+ pred_obj_scores: true
104
+ pred_obj_scores_mlp: true
105
+ fixed_no_obj_ptr: true
106
+ # multimask tracking settings
107
+ multimask_output_for_tracking: true
108
+ use_multimask_token_for_obj_ptr: true
109
+ multimask_min_pt_num: 0
110
+ multimask_max_pt_num: 1
111
+ use_mlp_for_obj_ptr_proj: true
112
+ # Compilation flag
113
+ compile_image_encoder: False
configs/sam2/sam2_hiera_l.yaml ADDED
@@ -0,0 +1,117 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 144
12
+ num_heads: 2
13
+ stages: [2, 6, 36, 4]
14
+ global_att_blocks: [23, 33, 43]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ window_spec: [8, 4, 16, 8]
17
+ neck:
18
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19
+ position_encoding:
20
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21
+ num_pos_feats: 256
22
+ normalize: true
23
+ scale: null
24
+ temperature: 10000
25
+ d_model: 256
26
+ backbone_channel_list: [1152, 576, 288, 144]
27
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28
+ fpn_interp_model: nearest
29
+
30
+ memory_attention:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttention
32
+ d_model: 256
33
+ pos_enc_at_input: true
34
+ layer:
35
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36
+ activation: relu
37
+ dim_feedforward: 2048
38
+ dropout: 0.1
39
+ pos_enc_at_attn: false
40
+ self_attention:
41
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
42
+ rope_theta: 10000.0
43
+ feat_sizes: [64, 64]
44
+ embedding_dim: 256
45
+ num_heads: 1
46
+ downsample_rate: 1
47
+ dropout: 0.1
48
+ d_model: 256
49
+ pos_enc_at_cross_attn_keys: true
50
+ pos_enc_at_cross_attn_queries: false
51
+ cross_attention:
52
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
53
+ rope_theta: 10000.0
54
+ feat_sizes: [64, 64]
55
+ rope_k_repeat: True
56
+ embedding_dim: 256
57
+ num_heads: 1
58
+ downsample_rate: 1
59
+ dropout: 0.1
60
+ kv_in_dim: 64
61
+ num_layers: 4
62
+
63
+ memory_encoder:
64
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
65
+ out_dim: 64
66
+ position_encoding:
67
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68
+ num_pos_feats: 64
69
+ normalize: true
70
+ scale: null
71
+ temperature: 10000
72
+ mask_downsampler:
73
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
74
+ kernel_size: 3
75
+ stride: 2
76
+ padding: 1
77
+ fuser:
78
+ _target_: sam2.modeling.memory_encoder.Fuser
79
+ layer:
80
+ _target_: sam2.modeling.memory_encoder.CXBlock
81
+ dim: 256
82
+ kernel_size: 7
83
+ padding: 3
84
+ layer_scale_init_value: 1e-6
85
+ use_dwconv: True # depth-wise convs
86
+ num_layers: 2
87
+
88
+ num_maskmem: 7
89
+ image_size: 1024
90
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ compile_image_encoder: False
configs/sam2/sam2_hiera_s.yaml ADDED
@@ -0,0 +1,116 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 11, 2]
14
+ global_att_blocks: [7, 10, 13]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [64, 64]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [64, 64]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ sigmoid_scale_for_mem_enc: 20.0
91
+ sigmoid_bias_for_mem_enc: -10.0
92
+ use_mask_input_as_output_without_sam: true
93
+ # Memory
94
+ directly_add_no_mem_embed: true
95
+ # use high-resolution feature map in the SAM mask decoder
96
+ use_high_res_features_in_sam: true
97
+ # output 3 masks on the first click on initial conditioning frames
98
+ multimask_output_in_sam: true
99
+ # SAM heads
100
+ iou_prediction_use_sigmoid: True
101
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102
+ use_obj_ptrs_in_encoder: true
103
+ add_tpos_enc_to_obj_ptrs: false
104
+ only_obj_ptrs_in_the_past_for_eval: true
105
+ # object occlusion prediction
106
+ pred_obj_scores: true
107
+ pred_obj_scores_mlp: true
108
+ fixed_no_obj_ptr: true
109
+ # multimask tracking settings
110
+ multimask_output_for_tracking: true
111
+ use_multimask_token_for_obj_ptr: true
112
+ multimask_min_pt_num: 0
113
+ multimask_max_pt_num: 1
114
+ use_mlp_for_obj_ptr_proj: true
115
+ # Compilation flag
116
+ compile_image_encoder: False
configs/sam2/sam2_hiera_t.yaml ADDED
@@ -0,0 +1,118 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 7, 2]
14
+ global_att_blocks: [5, 7, 9]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [64, 64]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [64, 64]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ # SAM decoder
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ # HieraT does not currently support compilation, should always be set to False
118
+ compile_image_encoder: False
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch
+ numpy
+ opencv-python
+ gradio
+ matplotlib
+ Pillow
+ ultralytics
+ diffusers
+ huggingface_hub
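The list covers the Gradio app and the vision stack, but not the SAM 2 package itself, which comes from the local `sam2` checkout referenced below. The processing code also hard-codes `torch.device("cuda")`; a CPU fallback such as the following is a possible adaptation, not part of the commit:

import torch

# Hedged sketch: pick CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running SAM 2 inference on {device}")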
sam2 ADDED
@@ -0,0 +1 @@
+ /data_sdf/yifan/sam2
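The `sam2` entry is a single-line file pointing at a local checkout (`/data_sdf/yifan/sam2`), most likely a symlink on the author's machine. `sam2segment_structure.py` makes that checkout importable by appending the directory to `sys.path`; the equivalent setup, assuming a local clone of the SAM 2 repository in a `sam2/` subdirectory, would be:

import os
import sys

# Sketch: expose a local SAM 2 checkout to Python, as sam2segment_structure.py does.
sam2_dir = os.path.join(os.path.dirname(__file__), "sam2")
if sam2_dir not in sys.path:
    sys.path.append(sam2_dir)

from sam2.build_sam import build_sam2  # now importable from the local checkout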
sam2segment_structure.py ADDED
@@ -0,0 +1,887 @@
1
+ import numpy as np
2
+ import sys,os
3
+ # sys.path.append("/home/yifan/sam2")
4
+ # sys.path.append("/data_sdf/yifan/miniconda3/envs/sam2/lib/python3.10/site-packages")
5
+ from huggingface_hub import hf_hub_download
6
+ sys.path.append(os.path.join(os.path.dirname(__file__), "sam2"))
7
+ from sam2.build_sam import build_sam2
8
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
9
+ import torch
10
+ import matplotlib.pyplot as plt
11
+ from PIL import Image
12
+ import cv2
13
+ import random
14
+ import warnings
15
+ warnings.filterwarnings("ignore", category=FutureWarning)
16
+ device = torch.device("cuda")
17
+ sam2_checkpoint = hf_hub_download(
18
+ repo_id="Evan73/sam2-models",
19
+ filename="sam2.1_hiera_large.pt"
20
+ )
21
+ model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
22
+ sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
23
+ # global sam2_model
24
+ predictor = SAM2ImagePredictor(sam2_model)
25
+ from ultralytics import YOLO
26
+ from diffusers.utils import load_image
27
+ import pickle
28
+ import os
29
+ import math
30
+ heatmap_zip = hf_hub_download(
31
+ repo_id="Evan73/attention-heatmaps",
32
+ filename="attention_heatmaps.zip"
33
+ )
34
+ import zipfile
35
+ import os
36
+
37
+ with zipfile.ZipFile(heatmap_zip, 'r') as zip_ref:
38
+ zip_ref.extractall("heatmaps_lda")
39
+
40
+ with open("heatmaps_lda/attention_heatmaps.pkl", "rb") as f:
41
+ heatmap_dict = pickle.load(f)
42
+
43
+ def load_yolov5_model():
44
+ # Load the official YOLO detector via ultralytics (the function name says YOLOv5, but a YOLO11 checkpoint is loaded below)
+ # model = torch.hub.load('ultralytics/yolov11', 'yolov11s') # a differently sized model can be chosen if needed
46
+ model = YOLO("yolo11n.pt")
47
+ class_names = model.names # class index to name mapping
48
+ print("YOLOv11 Class Names:")
49
+ for idx, name in class_names.items():
50
+ print(f"{idx}: {name}")
51
+ return model
52
+
53
+ # Check whether a point lies inside a detected vehicle region
54
+ def is_point_in_car_area(point, model, image):
55
+ """
+ Check whether the given point lies inside a detected vehicle region.
+ - point: point coordinates (x, y)
+ - model: YOLO model
+ - image: input image
+ """
+ # Run YOLO object detection
+ results = model(image) # detection results
63
+
64
+ # Vehicle class IDs (adjust to the model's label map)
+ # print("Detected classes:", results[0].boxes.cls.cpu().numpy())
+ car_class_id = [2, 5, 7] # COCO vehicle classes: 2 = car, 5 = bus, 7 = truck
67
+ image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
68
+
69
+ # Iterate over the detection results (the API supports batches; a single image is assumed here)
70
+ for result in results:
71
+ # Extract box xyxy coordinates, confidences, and classes
+ boxes = result.boxes.xyxy.cpu().numpy() # top-left and bottom-right corners
73
+ confidences = result.boxes.conf.cpu().numpy()
74
+ class_ids = result.boxes.cls.cpu().numpy().astype(int)
75
+
76
+ # Iterate over each detection box
77
+ for box, cls in zip(boxes, class_ids):
78
+ if cls in car_class_id:
79
+ x_min, y_min, x_max, y_max = box[:4]
80
+ # Draw the detection box (optional)
81
+ cv2.rectangle(image_bgr, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 2)
82
+ # Check whether the point falls inside the box
83
+ if (x_min <= point[0] <= x_max) and (y_min <= point[1] <= y_max):
84
+ cv2.imwrite("yolo_res.jpg", image_bgr)
85
+ return False
86
+ cv2.imwrite("yolo_res.jpg", image_bgr)
87
+ print("Detection result saved to yolo_res.jpg")
88
+ return True
89
+
90
+
91
+ def show_mask(mask, ax, image_path,random_color=False, borders=True, image=None, save_path=None):
92
+ """
+ Randomly pick two diagonal points inside the mask region and draw a rectangle on the original image.
+
+ Args:
+ - `mask`: mask region
+ - `ax`: matplotlib axis used for drawing
+ - `random_color`: whether to use a random color
+ - `borders`: whether to draw the mask borders
+ - `image`: original image on which the rectangle is drawn
+ - `save_path`: path where the result image is saved
+ """
103
+ if random_color:
104
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
105
+ else:
106
+ color = np.array([30/255, 144/255, 255/255, 0.6])
107
+
108
+ h, w = mask.shape[-2:]
109
+ mask = mask.astype(np.uint8)
110
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
111
+ cv2.imwrite("binary_mask.png", (mask * 255).astype(np.uint8))
112
+ print("Raw binary mask saved as binary_mask.png")
113
+ if borders:
114
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
115
+ contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
116
+ mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=5)
117
+ # print(f"Mask unique values: {np.unique(mask)}")
118
+ # print(f"Max value in mask: {mask.max()}, Min value in mask: {mask.min()}")
119
+ # If the original image is provided, draw rectangles on it
120
+ # size = 100
121
+ colors = [
122
+ (255, 0, 0), # red
+ (0, 255, 0), # green
+ (0, 0, 255), # blue
+ (255, 255, 0), # yellow
+ (255, 0, 255), # magenta
+ (0, 255, 255), # cyan
+ (255, 128, 0), # orange
+ (128, 0, 255), # purple
+ (128, 128, 128), # gray
+ (0, 128, 0) # dark green
132
+ ]
133
+
134
+ for idx, contour in enumerate(contours):
135
+ x, y, w, h = cv2.boundingRect(contour)
136
+ print(f"Contour {idx}: x={x}, y={y}, w={w}, h={h}")
137
+ color = colors[idx % len(colors)]
138
+ cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
139
+ middle_save_path = "contours_colored_result.png"
140
+ cv2.imwrite(middle_save_path, image)
141
+ print(f"Colored contour result saved to {middle_save_path}")
142
+ if image is not None:
143
+ # Find the mask boundaries
144
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
145
+ for contour in contours:
146
+ x, y, w, h = cv2.boundingRect(contour)
147
+ # print(x, y, w, h)
148
+ if w > 50 and h > 50:
149
+ for size in range(90,40,-5):
150
+ for _ in range(100):
151
+ random_x1 = random.randint(x, x + w - 50)
152
+ random_y1 = random.randint(y, y + h - 50)
153
+ random_x2 = random_x1 - size
154
+ random_y2 = random_y1 - size
155
+ # print(random_x1, random_y1,random_x2,random_y2)
156
+ # Draw the rectangle on the original image
+ # Save the result image
158
+ try:
159
+ if save_path and mask[random_y1, random_x1] == 1 and mask[random_y2, random_x2] == 1:
160
+ cv2.rectangle(image,(random_x2, random_y2), (random_x1, random_y1), (0, 255, 0), 2)
161
+ cv2.imwrite(save_path, image)
162
+ # generate_gt_mask_from_intersection([(random_x1, random_y1),(random_x2, random_y2)], yolo_boxes, image, sam2_model, threshold_iou=0.01)
163
+ print(f"Image with rectangle saved at {save_path}")
164
+ return (random_x1,random_y1),(random_x2,random_y2)
165
+ except:
166
+ pass
167
+ # cv2.rectangle(image,(random_x2, random_y2), (random_x1, random_y1), (0, 255, 0), 2)
168
+ # cv2.imwrite(save_path, image)
169
+ # print(f"Image with rectangle saved at {save_path}")
170
+ # break
171
+ for _ in range(100):
172
+ random_x1 = random.randint(x, x + w - 50)
173
+ random_y1 = random.randint(y, y + h - 50)
174
+ random_x2 = random_x1 + size
175
+ random_y2 = random_y1 + size
176
+ # print(mask[random_y1, random_x1] == 1,mask[random_y2, random_x2] == 1)
177
+ # Draw the rectangle on the original image
+ # Save the result image
179
+ try:
180
+ if save_path and mask[random_y1, random_x1] == 1 and mask[random_y2, random_x2] == 1:
181
+ cv2.rectangle(image,(random_x2, random_y2), (random_x1, random_y1), (0, 255, 0), 2)
182
+ cv2.imwrite(save_path, image)
183
+ print(f"Image with rectangle saved at {save_path}")
184
+ # generate_gt_mask_from_intersection([(random_x1, random_y1),(random_x2, random_y2)], yolo_boxes, image, sam2_model, threshold_iou=0.01)
185
+ return (random_x1,random_y1),(random_x2,random_y2)
186
+ except IndexError:
187
+ pass # a sampled corner fell outside the mask array; try another pair
188
+
189
+ ax.imshow(mask_image)
190
+ plt.axis('off')
191
+
192
+
193
+
194
+ def attention_mask(mask, ax, image_path, strategy="LOA", random_color=False, borders=True, image=None, save_path=None):
195
+ """
196
+ Select a square trigger region inside the mask, guided by the attention heatmap, and draw it on the original image.
197
+
198
+ Args:
199
+ - `mask`: mask region
200
+ - `ax`: matplotlib axis used for drawing
201
+ - `random_color`: whether to use a random overlay color
202
+ - `borders`: whether to draw the mask borders
203
+ - `image`: original image on which the rectangle is drawn
204
+ - `save_path`: path where the result image is saved
205
+ """
206
+ if random_color:
207
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
208
+ else:
209
+ color = np.array([30/255, 144/255, 255/255, 0.6])
210
+ orig_w, orig_h = image.shape[1],image.shape[0]
211
+ # print(image.shape)
212
+ h, w = mask.shape[-2:]
213
+ mask = mask.astype(np.uint8)
214
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
215
+ cv2.imwrite("binary_mask.png", (mask * 255).astype(np.uint8))
216
+ print("Raw binary mask saved as binary_mask.png")
217
+ # if borders:
218
+ # contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
219
+ # contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
220
+ # mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
221
+ # colors = [
222
+ # (255, 0, 0), # red
223
+ # (0, 255, 0), # green
224
+ # (0, 0, 255), # blue
225
+ # (255, 255, 0), # yellow
226
+ # (255, 0, 255), # magenta
227
+ # (0, 255, 255), # cyan
228
+ # (255, 128, 0), # orange
229
+ # (128, 0, 255), # purple
230
+ # (128, 128, 128), # gray
231
+ # (0, 128, 0) # dark green
232
+ # ]
233
+ # # print(mask.shape)
234
+ # for idx, contour in enumerate(contours):
235
+ # x, y, w, h = cv2.boundingRect(contour)
236
+ # print(f"Contour {idx}: x={x}, y={y}, w={w}, h={h}")
237
+ # color = colors[idx % len(colors)]
238
+ # cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
239
+ # middle_save_path = "contours_colored_result.png"
240
+ # cv2.imwrite(middle_save_path, image)
241
+ # print(f"Colored contour result saved to {middle_save_path}")
242
+ candidates = []
243
+ path = image_path
244
+ cls_heatmap = heatmap_dict[path]['cls_heatmap']
245
+ reg_heatmap = heatmap_dict[path]['reg_heatmap']
246
+ font = cv2.FONT_HERSHEY_SIMPLEX
247
+ if strategy == "LDA":
248
+ combined = cls_heatmap.astype(np.float32)
249
+ elif strategy in ("LOA", "LRA"):
250
+ combined = reg_heatmap.astype(np.float32)
251
+ print(mask.shape)
252
+ mask = cv2.resize(mask, (combined.shape[1], combined.shape[0]), interpolation=cv2.INTER_NEAREST)
253
+ mask = (mask > 0.5).astype(np.uint8)
254
+ cv2.imwrite("crop_binary_mask.png", (mask * 255).astype(np.uint8))
255
+ print("Resized binary mask saved as crop_binary_mask.png")
256
+ print(combined.shape)
257
+ vis_image = cv2.imread(image_path)
258
+ vis_image = cv2.resize(vis_image,(combined.shape[1],combined.shape[0]))
259
+ mask_image = cv2.resize(mask_image,(combined.shape[1],combined.shape[0]))
260
+ image = cv2.resize(image,(combined.shape[1],combined.shape[0]))
261
+ if borders:
262
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
263
+ contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
264
+ mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
265
+ colors = [
266
+ (255, 0, 0), # red
267
+ (0, 255, 0), # green
268
+ (0, 0, 255), # blue
269
+ (255, 255, 0), # yellow
270
+ (255, 0, 255), # magenta
271
+ (0, 255, 255), # cyan
272
+ (255, 128, 0), # orange
273
+ (128, 0, 255), # purple
274
+ (128, 128, 128), # gray
275
+ (0, 128, 0) # dark green
276
+ ]
277
+ # print(mask.shape)
278
+ for idx, contour in enumerate(contours):
279
+ x, y, w, h = cv2.boundingRect(contour)
280
+ print(f"Contour {idx}: x={x}, y={y}, w={w}, h={h}")
281
+ color = colors[idx % len(colors)]
282
+ cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
283
+ middle_save_path = "contours_colored_result.png"
284
+ cv2.imwrite(middle_save_path, image)
285
+ print(f"Colored contour result saved to {middle_save_path}")
286
+ if image is not None:
287
+ # Find the mask boundaries
288
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
289
+ # print(contours)
290
+ for contour in contours:
291
+ x, y, w, h = cv2.boundingRect(contour)
292
+ print("the contour is:",x, y, w, h)
293
+ if w > 50 and h > 50:
294
+ for size in range(50,40,-5):
295
+ for y_step in range(y, y+h - size,5):
296
+ for x_step in range(x, x+w - size,5):
297
+ x1, y1, x2, y2 = x_step, y_step, x_step + size, y_step + size
298
+ # print(mask[y1:y2, x1:x2].sum())
299
+ if mask[y1:y2, x1:x2].sum() >= size * size: # the candidate window must lie entirely inside the mask
300
+ heat_value = combined[y1:y2, x1:x2].mean()
301
+ # print("the heat_value is:",heat_value,y1,y2, x1,x2,combined.shape)
302
+ if not math.isnan(heat_value):
303
+ candidates.append(((x1, y1, x2, y2), heat_value))
304
+ cv2.rectangle(vis_image, (x1, y1), (x2, y2), (0, 255, 0), 1)
305
+ cv2.putText(vis_image, f'{heat_value:.1f}', (x1, y1 - 2), font, 0.4, (0, 0, 255), 1)
306
+ if not candidates:
307
+ print("⚠️ No candidate box lying fully inside the mask was found")
308
+ else:
309
+ break
310
+ cv2.imwrite("attention_vis.jpg", vis_image)
311
+ print("Attention candidate-box visualization saved to attention_vis.jpg")
312
+ # Sort candidates by heat value in descending order and pick the highest
313
+ candidates.sort(key=lambda x: x[1], reverse=True)
314
+ if candidates: # guard against an empty candidate list
+ print(save_path, candidates[0], candidates[-1])
315
+ for (x1, y1, x2, y2), _ in candidates:
316
+ try:
317
+ if mask[y1, x1] == 1 and mask[y2, x2] == 1:
318
+ # Visualize and save
319
+ if save_path:
320
+ image = cv2.imread(image_path)
321
+ image = cv2.resize(image,(combined.shape[1],combined.shape[0]))
322
+ cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
323
+ # os.makedirs(os.path.dirname(save_path), exist_ok=True)
324
+ cv2.imwrite(save_path, image)
325
+ print(f"Image with rectangle saved at {save_path}")
326
+ resize_w, resize_h = combined.shape[1],combined.shape[0]
327
+ scale_x = orig_w / resize_w
328
+ scale_y = orig_h / resize_h
329
+ x1_orig = int(x1 * scale_x)
330
+ x2_orig = int(x2 * scale_x)
331
+ y1_orig = int(y1 * scale_y)
332
+ y2_orig = int(y2 * scale_y)
333
+ cx = (x1_orig + x2_orig) // 2
334
+ cy = (y1_orig + y2_orig) // 2
335
+ target_size = 90
336
+ half = target_size // 2
337
+ x1_exp = max(0, cx - half)
338
+ y1_exp = max(0, cy - half)
339
+ x2_exp = min(orig_w - 1, cx + half)
340
+ y2_exp = min(orig_h - 1, cy + half)
341
+ print(f"Expanded coordinates on the original image: ({x1_exp}, {y1_exp}), ({x2_exp}, {y2_exp})")
342
+ image_full = cv2.imread(image_path) # read at the original resolution
343
+ cv2.rectangle(image_full, (x1_exp, y1_exp), (x2_exp, y2_exp), (0, 0, 255), 2)
344
+ cv2.imwrite("expanded_bbox_on_original.jpg", image_full)
345
+ print("📌 Expanded candidate box drawn on the original image and saved as expanded_bbox_on_original.jpg")
346
+ return (x1_exp, y1_exp), (x2_exp, y2_exp)
347
+ except Exception as e:
348
+ print("the error is:", e)
349
+ pass # e.g. an index fell out of bounds; move on to the next candidate
350
+ # for _ in range(100):
351
+ # random_x1 = random.randint(x, x + w - 50)
352
+ # random_y1 = random.randint(y, y + h - 50)
353
+ # random_x2 = random_x1 + size
354
+ # random_y2 = random_y1 + size
355
+ # # print(mask[random_y1, random_x1] == 1,mask[random_y2, random_x2] == 1)
356
+ # try:
357
+ # if save_path and mask[random_y1, random_x1] == 1 and mask[random_y2, random_x2] == 1:
358
+ # cv2.rectangle(image,(random_x2, random_y2), (random_x1, random_y1), (0, 255, 0), 2)
359
+ # cv2.imwrite(save_path, image)
360
+ # print(f"Image with rectangle saved at {save_path}")
361
+ # return (random_x1,random_y1),(random_x2,random_y2)
362
+ # except:
363
+ # pass
364
+
365
+ ax.imshow(mask_image)
366
+ plt.axis('off')
367
+
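+ # Illustrative sketch (not called by the pipeline) of the candidate scoring used in attention_mask above:
+ # slide a square window over the mask, keep windows that lie fully inside it, and rank them by mean heat.
+ # `toy_mask` and `toy_heatmap` are made-up arrays; the real heatmaps come from heatmap_dict.
+ def _demo_candidate_scoring(size=20):
+     toy_mask = np.zeros((100, 100), dtype=np.uint8)
+     toy_mask[20:80, 20:80] = 1  # a square mask region
+     toy_heatmap = np.random.rand(100, 100).astype(np.float32)
+     candidates = []
+     for y in range(20, 80 - size, 5):
+         for x in range(20, 80 - size, 5):
+             if toy_mask[y:y + size, x:x + size].sum() >= size * size:  # window fully inside the mask
+                 candidates.append(((x, y, x + size, y + size), toy_heatmap[y:y + size, x:x + size].mean()))
+     candidates.sort(key=lambda c: c[1], reverse=True)  # highest mean heat first
+     return candidates[0] if candidates else None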
368
+ def generate_gt_mask_from_intersection(random_rectangle, yolo_boxes, image, mask_img, sam2_model, threshold_iou):
369
+ """
370
+ Check whether the randomly generated rectangle is close enough to a YOLO box;
371
+ if it is, query SAM for a precise mask to use as the ground truth (GT).
372
+ """
373
+ image_np = np.array(image)
374
+ x1_rect, y1_rect = random_rectangle[0]
375
+ x2_rect, y2_rect = random_rectangle[1]
376
+ rect_mask = np.zeros(image_np.shape[:2], dtype=np.uint8)
377
+ cv2.rectangle(rect_mask, (x1_rect, y1_rect), (x2_rect, y2_rect), color=255, thickness=-1)
378
+
379
+ rect_box = [min(x1_rect, x2_rect), min(y1_rect, y2_rect), max(x1_rect, x2_rect), max(y1_rect, y2_rect)]
380
+
381
+ for box in yolo_boxes:
382
+ iou = calculate_iou(rect_box, box)
383
+ print(f"IoU with the YOLO box: {iou}, threshold: {threshold_iou}")
384
+
385
+ if iou >= threshold_iou:
386
+ # Sample three random points inside the YOLO box
387
+ x_min, y_min, x_max, y_max = box
388
+ input_point1 = (np.random.randint(x_min, x_max), np.random.randint(y_min, y_max))
389
+ input_point2 = (np.random.randint(x_min, x_max), np.random.randint(y_min, y_max))
390
+ input_point3 = (np.random.randint(x_min, x_max), np.random.randint(y_min, y_max))
391
+
392
+ # Use SAM to generate a precise mask
393
+ gt_mask = get_gt_mask_from_sam(image, sam2_model, [input_point1, input_point2,input_point3], rect_mask)
394
+ mask_img[gt_mask > 0] = 0
395
+ # Save the GT mask
396
+ cv2.imwrite('gt_mask_from_sam.png', gt_mask)
397
+ print("GT mask generated by SAM saved to gt_mask_from_sam.png")
398
+
399
+ return gt_mask,mask_img
400
+ h, w = image_np.shape[:2]
401
+ black_mask = np.zeros((h, w), dtype=np.uint8)
402
+ no_match_save_path = 'gt_mask_from_sam.png'
403
+ cv2.imwrite(no_match_save_path, black_mask)
404
+ print("No YOLO box satisfied the IoU threshold.")
404
+ print(f"No match found; an empty mask was saved to {no_match_save_path}")
406
+ return None,mask_img
407
+
408
+ def calculate_iou(boxA, boxB):
409
+ """计算两个box的IoU."""
410
+ xA = max(boxA[0], boxB[0])
411
+ yA = max(boxA[1], boxB[1])
412
+ xB = min(boxA[2], boxB[2])
413
+ yB = min(boxA[3], boxB[3])
414
+
415
+ inter_area = max(0, xB - xA + 1) * max(0, yB - yA + 1)
416
+
417
+ boxA_area = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
418
+ boxB_area = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
419
+
420
+ iou = inter_area / float(boxA_area + boxB_area - inter_area)
421
+ return iou
422
+
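+ # Worked example (illustrative only): two 100x100 boxes offset by 50 px share a 51x51 patch
+ # under the +1 convention above, so the IoU is 51*51 / (101*101 + 101*101 - 51*51) = 2601 / 17801 ≈ 0.146.
+ def _demo_calculate_iou():
+     return calculate_iou([0, 0, 100, 100], [50, 50, 150, 150])  # ≈ 0.146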
423
+ def get_gt_mask_from_sam(image, sam2_model, input_points, rect_mask):
424
+ """使用SAM根据两个点生成掩码,并保存选取点和掩码图"""
425
+ predictor = SAM2ImagePredictor(sam2_model)
426
+ print("load sam2")
427
+ predictor.set_image(image)
428
+
429
+ input_point_np = np.array(input_points)
430
+ input_label = np.array([1] * len(input_points)) # one positive label per prompt point
431
+
432
+ masks, _, _ = predictor.predict(
433
+ point_coords=input_point_np,
434
+ point_labels=input_label,
435
+ multimask_output=False,
436
+ )
437
+
438
+ mask_img = masks[0].astype(np.uint8) * 255
439
+ # mask_img[rect_mask == 255] = 0 # set the random_rectangle region to black
440
+
441
+ # Save the mask image generated by SAM
442
+ mask_save_path = 'sam_gt_mask.jpg'
443
+ cv2.imwrite(mask_save_path, mask_img)
444
+ print(f"Mask generated by SAM saved to {mask_save_path}")
445
+
446
+ # Draw the selected prompt points on the original image
447
+ image_with_points = np.array(image).copy()
448
+ for point in input_points:
449
+ cv2.circle(image_with_points, point, radius=5, color=(255, 0, 0), thickness=-1)
450
+
451
+ # Save the original image with the marked points
452
+ point_marked_save_path = 'image_with_points.jpg'
453
+ image_bgr = cv2.cvtColor(image_with_points, cv2.COLOR_RGB2BGR)
454
+ cv2.imwrite(point_marked_save_path, image_bgr)
455
+ print(f"Image with marked points saved to {point_marked_save_path}")
456
+
457
+ return mask_img
458
+
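+ # Usage sketch (assumes `sam2_model` is already loaded and `img` is an RGB numpy array;
+ # the prompt coordinates below are arbitrary examples, not values used by the pipeline):
+ # img = np.array(Image.open("some_frame.jpg").convert("RGB"))
+ # rect = np.zeros(img.shape[:2], dtype=np.uint8)
+ # gt = get_gt_mask_from_sam(img, sam2_model, [(400, 300), (420, 310), (410, 305)], rect)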
459
+ def show_points(coords, labels, ax, marker_size=375):
460
+ pos_points = coords[labels==1]
461
+ neg_points = coords[labels==0]
462
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
463
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
464
+
465
+ def show_box(box, ax):
466
+ x0, y0 = box[0], box[1]
467
+ w, h = box[2] - box[0], box[3] - box[1]
468
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
469
+
470
+ def display_mask(mask, ax, random_color=False, borders = True):
471
+ if random_color:
472
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
473
+ else:
474
+ color = np.array([30/255, 144/255, 255/255, 0.6])
475
+ h, w = mask.shape[-2:]
476
+ mask = mask.astype(np.uint8)
477
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
478
+ if borders:
479
+ import cv2
480
+ contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
481
+ # Try to smooth contours
482
+ contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
483
+ mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
484
+ cv2.imwrite("check.jpg", mask_image)
485
+ ax.imshow(mask_image)
486
+
487
+ def random_points_below(point, radius, min_distance, model, image, max_attempts=100):
488
+ """
489
+ Randomly pick two points in the region 50 pixels below the given point until the constraints are satisfied.
490
+
491
+ Args:
492
+ - point: coordinate in (x, y) format
493
+ - radius: maximum offset of the random points
494
+ - min_distance: minimum distance between the two random points
495
+ - max_attempts: maximum number of attempts, to avoid an infinite loop
496
+
497
+ Returns:
498
+ - the coordinates of the two random points, or None if no suitable pair is found
499
+ """
500
+ for _ in range(max_attempts):
501
+ # Randomly pick two points in the region 50 pixels below the given point
502
+ x1 = random.randint(point[0] - radius, point[0] + radius)
503
+ y1 = random.randint(point[1] + 50, point[1] + 50 + radius) # 50 pixels below the point
504
+
505
+ x2 = random.randint(point[0] - radius, point[0] + radius)
506
+ y2 = random.randint(point[1] + 50, point[1] + 50 + radius) # 50 pixels below the point
507
+
508
+ # Compute the Euclidean distance between the two points
509
+ distance = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)
510
+
511
+ # Check the distance constraint and that both points lie inside a detected vehicle region
512
+ if distance >= min_distance and is_point_in_car_area((x1, y1), model, image) and is_point_in_car_area((x2, y2), model, image) :
513
+ return [(x1, y1), (x2, y2)]
514
+
515
+ # Return None if no suitable pair is found within the maximum number of attempts
516
+ return None
517
+
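+ # Usage sketch (assumes a loaded YOLO `model` and a BGR `image`; the seed point is an arbitrary example):
+ # pts = random_points_below((770, 325), radius=50, min_distance=40, model=model, image=image)
+ # if pts is not None:
+ #     (x1, y1), (x2, y2) = pts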
518
+
519
+ def show_masks(image, masks, scores, image_path, strategy, point_coords=None, box_coords=None, input_labels=None, borders=True, save_path=None):
520
+ for i, (mask, score) in enumerate(zip(masks, scores)):
521
+ plt.figure(figsize=(10, 10))
522
+ plt.imshow(image)
523
+ display_mask(mask, plt.gca(), borders=borders)
524
+ if point_coords is not None:
525
+ assert input_labels is not None
526
+ show_points(point_coords, input_labels, plt.gca())
527
+ if box_coords is not None:
528
+ # boxes
529
+ show_box(box_coords, plt.gca())
530
+ plt.axis('off')
531
+ plt.savefig('check.jpg', bbox_inches='tight', pad_inches=0) # save the figure
532
+ point1,point2 = attention_mask(mask, plt.gca(), image_path,strategy,borders=borders, image=image, save_path=save_path)
533
+ return point1,point2
534
+
535
+ def random_crop(image, target_width, target_height, mask_point1, mask_point2):
536
+ # global global_mask_point1_relative, global_mask_point2_relative
537
+ """从两个对角点的中点裁剪指定宽度和高度的区域,避免超出图像边界"""
538
+ width, height = image.size
539
+ # Compute the midpoint of the two diagonal points
540
+ center_x = (mask_point1[0] + mask_point2[0]) // 2
541
+ center_y = (mask_point1[1] + mask_point2[1]) // 2
542
+
543
+ # Compute the top-left and bottom-right corners of the crop region
544
+ left = center_x - target_width // 2
545
+ top = center_y - target_height // 2
546
+ right = left + target_width
547
+ bottom = top + target_height
548
+
549
+ # Make sure the crop region does not extend beyond the image boundaries
550
+ if left < 0:
551
+ left = 0
552
+ right = target_width
553
+ if top < 0:
554
+ top = 0
555
+ bottom = target_height
556
+ if right > width:
557
+ right = width
558
+ left = width - target_width
559
+ if bottom > height:
560
+ bottom = height
561
+ top = height - target_height
562
+
563
+ # Compute the padding
564
+ top_padding = max(0, top)
565
+ left_padding = max(0, left)
566
+
567
+ # Crop the image
568
+ cropped_image = image.crop((left, top, right, bottom))
569
+
570
+ global_mask_point1_relative = (mask_point1[0] - left, mask_point1[1] - top)
571
+ global_mask_point2_relative = (mask_point2[0] - left, mask_point2[1] - top)
572
+ print("Relative positions of the points after cropping:")
573
+ print("mask_point1:", global_mask_point1_relative)
574
+ print("mask_point2:", global_mask_point2_relative)
575
+ return cropped_image, top_padding, left_padding,global_mask_point1_relative,global_mask_point2_relative
576
+
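+ # Usage sketch (a PIL image is assumed; the two diagonal points are made-up coordinates):
+ # crop, top_pad, left_pad, p1_rel, p2_rel = random_crop(Image.open("frame.jpg"), 512, 512, (700, 300), (790, 390))
+ # crop.size is (512, 512) as long as the source image is at least 512x512.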
577
+ def get_left_right_points(lane_data,image_path):
578
+ lanes = lane_data["lanes"]
579
+ h_samples = lane_data["h_samples"]
580
+ model = load_yolov5_model()
581
+ # Find the middle index of h_samples
582
+ mid_idx = len(h_samples) // 2
583
+ image = cv2.imread(image_path)
584
+ # Store the leftmost and rightmost points
585
+ left_point = None
586
+ right_point = None
587
+ points = []
588
+ # Iterate over each lane
589
+ for lane in lanes:
590
+ # Drop invalid points whose x value is -2
591
+ valid_points = [(x, y) for x, y in zip(lane, h_samples) if x != -2]
592
+
593
+ if valid_points:
594
+ if lane[mid_idx] != -2:
595
+ for i in range(mid_idx-2,0,-1):
596
+ left_point = lane[i]
597
+ print(left_point)
598
+ if lane[i] != -2:
599
+ point = (left_point,h_samples[i])
600
+ FLAG = is_point_in_car_area(point, model, image)
601
+ print(point,FLAG)
602
+ if FLAG:
603
+ points.append((left_point,h_samples[i]))
604
+ break
605
+ else:
606
+ point = (1540/2, 590/2+30) # initial seed point
607
+ radius = 50 # maximum offset of the random points
608
+ min_distance = 40 # minimum distance between the two points
609
+ points = random_points_below(point, radius, min_distance,model,image)
610
+ # first_non_minus_two = next((x for x in lane if x != -2), None)
611
+ # if first_non_minus_two:
612
+ # idx = lane.index(first_non_minus_two)
613
+ # for i in range(idx+5,idx,-1):
614
+ # left_point = lane[i]
615
+ # if lane[i] != -2:
616
+ # point = (left_point,h_samples[i])
617
+ # FLAG = is_point_in_car_area(point, model, image)
618
+ # if FLAG:
619
+ # points.append((left_point,h_samples[i]))
620
+ # break
621
+
622
+ # return left_point, right_point
623
+ return points
624
+
625
+ def sam2segment(image_path, points, strategy):
626
+ # print(points)
627
+ image = Image.open(image_path)
628
+ image = np.array(image.convert("RGB"))
629
+ predictor.set_image(image)
630
+ # print([points[0][0], points[0][1]])
631
+ input_point = np.array([(points[0][0], points[0][1])])
632
+ input_label = np.array([1])
633
+ masks, scores, logits = predictor.predict(
634
+ point_coords=input_point,
635
+ point_labels=input_label,
636
+ multimask_output=True,
637
+ )
638
+ sorted_ind = np.argsort(scores)[::-1]
639
+ masks = masks[sorted_ind]
640
+ scores = scores[sorted_ind]
641
+ logits = logits[sorted_ind]
642
+ #mask
643
+ mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
644
+ points_set = []
645
+ for point in points:
646
+ points_set.append((point[0], point[1]))
647
+ # print(points_set)
648
+ input_point = np.array(points_set)
649
+ input_label = np.array([1]*len(points_set))
650
+ masks, scores, _ = predictor.predict(
651
+ point_coords=input_point,
652
+ point_labels=input_label,
653
+ mask_input=mask_input[None, :, :],
654
+ multimask_output=False,
655
+ )
656
+ # random_mask_selection(image, masks, mask_index=0,output_path="cropped_image.jpg")
657
+ point1,point2 = show_masks(image, masks, scores, image_path, strategy,save_path="masked_image.jpg")
658
+ return point1,point2
659
+
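+ # Usage sketch (assumes the global `predictor` is initialized and the image exists on disk;
+ # the prompt points would normally come from get_left_right_points, the values here are placeholders):
+ # p1, p2 = sam2segment("driver_182_30frame/06010513_0036.MP4/00270.jpg", [(770, 330), (800, 340)], "LDA")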
660
+ def draw_point(image_path,points):
661
+ image = cv2.imread(image_path)
662
+ if image is not None:
663
+ # Draw the points
664
+ for point in points:
665
+ cv2.circle(image, point, radius=5, color=(0, 255, 0), thickness=-1) # green dot
666
+
667
+ # Save the image
668
+ output_path = "output_image_with_points.jpg"
669
+ cv2.imwrite(output_path, image)
670
+ print(f"Image saved with points at {output_path}")
671
+ else:
672
+ print("Error: Image could not be loaded.")
673
+
674
+ def generate_mask(original_img_path, point1, point2):
675
+ """根据坐标生成掩码图像"""
676
+ # Read the original image
677
+ original_img = cv2.imread(original_img_path)
678
+
679
+ # Get the size of the original image
680
+ height, width, _ = original_img.shape
681
+
682
+ # Create a black mask image with the same size as the original
683
+ mask = np.zeros((height, width), dtype=np.uint8)
684
+
685
+ # Compute the point 95% of the way from point1 to point2 (the variable name still says three-quarter)
686
+ three_quarter_point = (
687
+ int(point1[0] + 0.95 * (point2[0] - point1[0])), # x coordinate
688
+ int(point1[1] + 0.95 * (point2[1] - point1[1])) # y coordinate
689
+ )
690
+
691
+ # Draw a white rectangle (fill that region with white)
692
+ cv2.rectangle(mask, point1, three_quarter_point, color=255, thickness=-1)
693
+
694
+ # Save the generated mask image
695
+ mask_path = original_img_path.replace('test.jpg', 'mask_test.jpg')
696
+ cv2.imwrite(mask_path, mask)
697
+ print(mask_path)
698
+ return mask_path, point1, three_quarter_point
699
+
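+ # Usage sketch (note: the save path is derived by replacing 'test.jpg' in the input path,
+ # so the input filename is assumed to end in 'test.jpg'; the coordinates are illustrative):
+ # mask_path, p1, p3q = generate_mask("culane_test.jpg", (700, 300), (790, 390))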
700
+ def extract_lanes_in_crop(lane_data, crop_x_min, crop_x_max, crop_y_min, crop_y_max):
701
+ """
702
+ Filter the TuSimple `lanes`, keeping only the parts inside the `crop`
703
+ """
704
+ cropped_lanes = []
705
+ for lane in lane_data["lanes"]:
706
+ cropped_lane = []
707
+ for x, y in zip(lane, lane_data["h_samples"]):
708
+ if x != -2 and crop_x_min <= x <= crop_x_max and crop_y_min <= y <= crop_y_max:
709
+ cropped_lane.append((x, y))
710
+ # new_x = x - crop_x_min
711
+ # new_y = y - crop_y_min
712
+ # cropped_lane.append((new_x, new_y))
713
+ if cropped_lane:
714
+ cropped_lanes.append(cropped_lane)
715
+
716
+ return cropped_lanes
717
+
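+ # Illustrative example with toy lane data (not TuSimple values): only the points inside the crop window survive.
+ def _demo_extract_lanes_in_crop():
+     toy = {"lanes": [[100, 150, -2, 250]], "h_samples": [200, 210, 220, 230]}
+     # keep x in [120, 260] and y in [205, 235] -> (150, 210) and (250, 230) remain
+     return extract_lanes_in_crop(toy, 120, 260, 205, 235)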
718
+
719
+ def generate_trigger_crop(image_path: str, lane_data: dict):
720
+ """
721
+ Given an image path, return the paths of the processed crop image and crop mask image.
722
+ """
723
+ # 1. Get the trigger points
724
+ points = get_left_right_points(lane_data, image_path)
725
+ print(f"[INFO] Trigger points: {points}")
726
+ draw_point(image_path, points)
727
+
728
+ # 2. Use SAM2 to get the mask points
729
+ image = load_image(image_path)
730
+ mask_point1, mask_point2 = sam2segment(image_path, points, "LDA")
731
+
732
+ # 3. Crop the original image
733
+ input_image, *_ = random_crop(image, 512, 512, mask_point1, mask_point2)
734
+ input_crop_path = "crop.jpg"
735
+ input_image.save(input_crop_path)
736
+
737
+ # 4. Generate the trigger mask
738
+ mask_path, point1, point2 = generate_mask(image_path, mask_point1, mask_point2)
739
+ mask_img = load_image(mask_path)
740
+ mask_img, *_ = random_crop(mask_img, 512, 512, mask_point1, mask_point2)
741
+ crop_mask_path = "crop_mask.jpg"
742
+ cv2.imwrite(crop_mask_path, np.array(mask_img))
743
+
744
+ return input_crop_path, crop_mask_path
745
+
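+ # Usage sketch (the image path and lane_data below are placeholders; lane_data must follow the
+ # TuSimple-style dict with "lanes", "h_samples" and "raw_file" keys used in __main__):
+ # crop_path, mask_path = generate_trigger_crop("driver_182_30frame/06010513_0036.MP4/00270.jpg", lane_data)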
746
+ if __name__ == "__main__":
747
+ lane_data = {"lanes": [[-2, -2, -2, -2, -2, -2, -2, 814, 751, 688, 625, 562, 500, 438, 373, 305, 234, 160, 88, 16, -64, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2], [-2, -2, -2, -2, -2, -2, -2, 818, 801, 784, 768, 751, 734, 717, 701, 685, 668, 651, 634, 618, 601, 585, 568, 551, 535, 518, 502, 484, 468, 451, 435, 418, 401, 385, 368, 351, 335, 318, 301, 287], [-2, -2, -2, -2, -2, -2, -2, 863, 872, 881, 890, 899, 908, 918, 927, 936, 945, 954, 964, 972, 982, 991, 1000, 1009, 1018, 1027, 1036, 1046, 1055, 1064, 1073, 1082, 1091, 1100, 1109, 1119, 1128, 1137, 1146, 1154]], "h_samples": [200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520, 530, 540, 550, 560, 570, 580, 590], "raw_file": "driver_182_30frame/06010513_0036.MP4/00270.jpg"}
748
+
749
+ image_path = "driver_182_30frame/06010513_0036.MP4/00270.jpg"
750
+ points = get_left_right_points(lane_data,image_path)
751
+ print(points)
752
+ draw_point(image_path,points)
753
+ # left_point, right_point = get_left_right_points(lane_data)
754
+ # print(f"Left point: {left_point}, Right point: {right_point}")
755
+ # sam2segment(image_path,left_point, right_point)
756
+ image = load_image(image_path)
757
+ mask_point1,mask_point2 = sam2segment(image_path,points,"LDA")
758
+ input_image,top_padding,left_padding,global_mask_point1_relative,global_mask_point2_relative = random_crop(image, 512, 512, mask_point1, mask_point2)
759
+ input_image.save("crop.jpg") # saved directly with PIL's save() method
760
+ print("Cropped image saved at crop.jpg")
761
+ mask_path, point1, point2 = generate_mask('culane_test.jpg', mask_point1, mask_point2)
762
+ mask_img = load_image(mask_path)
763
+ mask_img,top_padding,left_padding,global_mask_point1_relative,global_mask_point2_relative = random_crop(mask_img, 512, 512,mask_point1,mask_point2)
764
+
765
+ mask_img = np.array(mask_img)
766
+ # print(mask_img.shape)
767
+ model = load_yolov5_model()
768
+ yolo_results = model(input_image)
769
+ yolo_boxes = []
770
+ car_class_id = [2, 5, 7] # class IDs for car, bus, truck, etc.; adjust as needed
771
+
772
+ for result in yolo_results:
773
+ boxes = result.boxes.xyxy.cpu().numpy()
774
+ class_ids = result.boxes.cls.cpu().numpy().astype(int)
775
+
776
+ for box, cls in zip(boxes, class_ids):
777
+ if cls in car_class_id:
778
+ x_min, y_min, x_max, y_max = box[:4]
779
+ yolo_boxes.append([int(x_min), int(y_min), int(x_max), int(y_max)])
780
+ _,mask_img=generate_gt_mask_from_intersection([global_mask_point1_relative,global_mask_point2_relative], yolo_boxes, input_image, mask_img,sam2_model, threshold_iou=0.01)
781
+ cv2.imwrite("crop_mask.jpg", mask_img)
782
+
783
+ print("Mask successfully saved to crop_mask.jpg")
784
+ crop_x_min = min(mask_point1[0], mask_point2[0])
785
+ crop_x_max = max(mask_point1[0], mask_point2[0])
786
+ crop_y_min = min(mask_point1[1], mask_point2[1])
787
+ crop_y_max = max(mask_point1[1], mask_point2[1])
788
+
789
+
790
+ def extract_lanes_in_crop(lane_data, crop_x_min, crop_x_max, crop_y_min, crop_y_max):
791
+ """
792
+ Filter the TuSimple `lanes`, keeping only the parts inside the `crop`
793
+ """
794
+ cropped_lanes = []
795
+ for lane in lane_data["lanes"]:
796
+ cropped_lane = []
797
+ for x, y in zip(lane, lane_data["h_samples"]):
798
+ if x != -2 and crop_x_min <= x <= crop_x_max and crop_y_min <= y <= crop_y_max:
799
+ cropped_lane.append((x, y))
800
+ # new_x = x - crop_x_min
801
+ # new_y = y - crop_y_min
802
+ # cropped_lane.append((new_x, new_y))
803
+ if cropped_lane:
804
+ cropped_lanes.append(cropped_lane)
805
+
806
+ return cropped_lanes
807
+
808
+ # **Get the lanes that fall within the crop**
809
+ cropped_lanes = extract_lanes_in_crop(lane_data, crop_x_min, crop_x_max, crop_y_min, crop_y_max)
810
+ # print(cropped_lanes)
811
+ # def draw_lane_mask(image, lanes):
812
+ # """
813
+ # Draw the `lane_mask` only within the `crop` image
814
+ # """
815
+ # height, width, _ = image.shape
816
+ # lane_mask = np.zeros((height, width), dtype=np.uint8)
817
+
818
+ # for lane in lanes:
819
+ # points = np.array(lane, dtype=np.int32)
820
+ # cv2.polylines(lane_mask, [points], isClosed=False, color=255, thickness=5)
821
+
822
+ # return lane_mask
823
+
824
+ # crop_image = load_image("crop.jpg").convert("RGB")
825
+ # crop_image = np.array(crop_image)
826
+ # lane_mask = draw_lane_mask(crop_image, cropped_lanes)
827
+ def draw_lane_mask_on_original(image, cropped_lanes):
828
+ """
829
+ Draw, on the original image, only the lanes contained in **cropped_lanes**
830
+ """
831
+ height, width, _ = image.shape
832
+ lane_mask = np.zeros((height, width), dtype=np.uint8)
833
+
834
+ for lane in cropped_lanes:
835
+ points = np.array(lane, dtype=np.int32)
836
+ cv2.polylines(lane_mask, [points], isClosed=False, color=255, thickness=10)
837
+
838
+ return lane_mask
839
+
840
+ def random_crop_lane(image, target_width, target_height, mask_point1, mask_point2):
841
+ """从两个对角点的中点裁剪指定宽度和高度的区域,避免超出图像边界"""
842
+
843
+ # **Make sure image is a NumPy array**
844
+ if isinstance(image, Image.Image):
845
+ image = np.array(image)
846
+
847
+ height, width = image.shape[:2] # get the size of the NumPy array
848
+
849
+ # Compute the midpoint of the two diagonal points
850
+ center_x = (mask_point1[0] + mask_point2[0]) // 2
851
+ center_y = (mask_point1[1] + mask_point2[1]) // 2
852
+
853
+ # Compute the top-left and bottom-right corners of the crop region
854
+ left = max(0, center_x - target_width // 2)
855
+ top = max(0, center_y - target_height // 2)
856
+ right = min(width, left + target_width)
857
+ bottom = min(height, top + target_height)
858
+
859
+ # Compute the padding (if the crop region exceeds the boundaries)
860
+ top_padding = max(0, target_height - (bottom - top))
861
+ left_padding = max(0, target_width - (right - left))
862
+
863
+ # **Crop with NumPy slicing**
864
+ cropped_image = image[top:bottom, left:right]
865
+
866
+ return cropped_image, top_padding, left_padding
867
+ # **Draw the lane_mask on the original image**
868
+ raw_image = np.array(load_image(image_path).convert("RGB"))
869
+ lane_mask = draw_lane_mask_on_original(raw_image, cropped_lanes)
870
+ lane_mask_pil = Image.fromarray(lane_mask)
871
+ crop_image,top_padding,left_padding,global_mask_point1_relative,global_mask_point2_relative = random_crop(lane_mask_pil, 512, 512,mask_point1,mask_point2)
872
+
873
+ # **Save the lane_mask**
874
+ crop_image.save("lane_mask_crop.jpg")
875
+ print("✅ Lane mask saved as lane_mask_crop.jpg")
876
+
877
+ crop_img = cv2.imread("crop.jpg") # read the cropped image (BGR)
878
+ mask_img = cv2.imread("crop_mask.jpg", cv2.IMREAD_GRAYSCALE) # read the mask (grayscale)
879
+ if crop_img.shape[:2] != mask_img.shape:
880
+ print("⚠️ Resizing mask to match crop image size...")
881
+ mask_img = cv2.resize(mask_img, (crop_img.shape[1], crop_img.shape[0]))
882
+ white_overlay = np.ones_like(crop_img) * 255 # all-white image
883
+ masked_result = np.where(mask_img[:, :, None] == 255, white_overlay, crop_img) # replace only the white mask region
884
+
885
+ # **Save the composited image**
886
+ cv2.imwrite("crop_with_mask.jpg", masked_result)
887
+ print("✅ Composited mask image saved to crop_with_mask.jpg")
yolo11n.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ebbc80d4a7680d14987a577cd21342b65ecfd94632bd9a8da63ae6417644ee1
3
+ size 5613764