yansong1616 committed
Commit 45afc96 · verified · 1 Parent(s): e4d53e2

Update SAM2/sam2/utils/misc.py

Files changed (1):
  1. SAM2/sam2/utils/misc.py +243 -242
SAM2/sam2/utils/misc.py CHANGED
@@ -1,242 +1,243 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.

 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

 import os
 import warnings
 from threading import Thread

 import numpy as np
 import torch
 from PIL import Image
 from tqdm import tqdm


 def get_sdpa_settings():
     if torch.cuda.is_available():
         old_gpu = torch.cuda.get_device_properties(0).major < 7
         # only use Flash Attention on Ampere (8.0) or newer GPUs
         use_flash_attn = torch.cuda.get_device_properties(0).major >= 8
         if not use_flash_attn:
             warnings.warn(
                 "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.",
                 category=UserWarning,
                 stacklevel=2,
             )
         # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only
         # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases)
         pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
         if pytorch_version < (2, 2):
             warnings.warn(
                 f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
                 "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
                 category=UserWarning,
                 stacklevel=2,
             )
         math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn
     else:
         old_gpu = True
         use_flash_attn = False
         math_kernel_on = True

     # guo yansong: TODO this machine may not support Flash Attention, so Flash Attention is force-disabled here
     # return True, False, True

     return old_gpu, use_flash_attn, math_kernel_on


 def get_connected_components(mask):
     """
     Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W).

     Inputs:
     - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is
       background.

     Outputs:
     - labels: A tensor of shape (N, 1, H, W) containing the connected component labels
       for foreground pixels and 0 for background pixels.
     - counts: A tensor of shape (N, 1, H, W) containing the area of the connected
       components for foreground pixels and 0 for background pixels.
     """
     from sam2 import _C

     return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())


 def mask_to_box(masks: torch.Tensor):
     """
     Compute bounding boxes given an input mask.

     Inputs:
     - masks: [B, 1, H, W] masks, dtype=torch.Tensor

     Returns:
     - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
     """
     B, _, h, w = masks.shape
     device = masks.device
     xs = torch.arange(w, device=device, dtype=torch.int32)
     ys = torch.arange(h, device=device, dtype=torch.int32)
     grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy")
     grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w)
     grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w)
     min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1)
     max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1)
     min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1)
     max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1)
     bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1)

     return bbox_coords


 def _load_img_as_tensor(img_path, image_size):
     img_pil = Image.open(img_path)
     img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
     if img_np.dtype == np.uint8:  # np.uint8 is expected for JPEG images
         img_np = img_np / 255.0
     else:
         raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
     img = torch.from_numpy(img_np).permute(2, 0, 1)
     video_width, video_height = img_pil.size  # the original video size
     return img, video_height, video_width


 class AsyncVideoFrameLoader:
     """
     A list of video frames to be loaded asynchronously without blocking session start.
     """

     def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
         self.img_paths = img_paths
         self.image_size = image_size
         self.offload_video_to_cpu = offload_video_to_cpu
         self.img_mean = img_mean
         self.img_std = img_std
         # items in `self.images` will be loaded asynchronously
         self.images = [None] * len(img_paths)
         # catch and raise any exceptions in the async loading thread
         self.exception = None
         # video_height and video_width will be filled when loading the first image
         self.video_height = None
         self.video_width = None

         # load the first frame to fill video_height and video_width and also
         # to cache it (since it's most likely where the user will click)
         self.__getitem__(0)

         # load the rest of frames asynchronously without blocking the session start
         def _load_frames():
             try:
                 for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
                     self.__getitem__(n)
             except Exception as e:
                 self.exception = e

         self.thread = Thread(target=_load_frames, daemon=True)
         self.thread.start()

     def __getitem__(self, index):
         if self.exception is not None:
             raise RuntimeError("Failure in frame loading thread") from self.exception

         img = self.images[index]
         if img is not None:
             return img

         img, video_height, video_width = _load_img_as_tensor(
             self.img_paths[index], self.image_size
         )
         self.video_height = video_height
         self.video_width = video_width
         # normalize by mean and std
         img -= self.img_mean
         img /= self.img_std
         if not self.offload_video_to_cpu:
             img = img.cuda(non_blocking=True)
         self.images[index] = img
         return img

     def __len__(self):
         return len(self.images)


 def load_video_frames(
     video_path,
     image_size,
     offload_video_to_cpu,
     img_mean=(0.485, 0.456, 0.406),
     img_std=(0.229, 0.224, 0.225),
     async_loading_frames=False,
 ):
     """
     Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).

     The frames are resized to image_size x image_size and are loaded to GPU if
     `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.

     You can load a frame asynchronously by setting `async_loading_frames` to `True`.
     """
     if isinstance(video_path, str) and os.path.isdir(video_path):
         jpg_folder = video_path
     else:
         raise NotImplementedError("Only JPEG frames are supported at this moment")


     frame_names = [
         p
         for p in sorted(os.listdir(jpg_folder))
         if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png"]
     ]

     num_frames = len(frame_names)
     if num_frames == 0:
         raise RuntimeError(f"no images found in {jpg_folder}")
     img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
     img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
     img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]

     if async_loading_frames:
         lazy_images = AsyncVideoFrameLoader(
             img_paths, image_size, offload_video_to_cpu, img_mean, img_std
         )
         return lazy_images, lazy_images.video_height, lazy_images.video_width

     images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
     for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
         images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
     if not offload_video_to_cpu:
-        images = images.cuda()
-        img_mean = img_mean.cuda()
-        img_std = img_std.cuda()
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        images = images.to(device)
+        img_mean = img_mean.to(device)
+        img_std = img_std.to(device)
     # normalize by mean and std
     images -= img_mean
     images /= img_std
     return images, video_height, video_width


 def fill_holes_in_mask_scores(mask, max_area):
     """
     A post processor to fill small holes in mask scores with area under `max_area`.
     """
     # Holes are those connected components in background with area <= self.max_area
     # (background regions are those with mask scores <= 0)
     assert max_area > 0, "max_area must be positive"
     labels, areas = get_connected_components(mask <= 0)
     is_hole = (labels > 0) & (areas <= max_area)
     # We fill holes with a small positive mask score (0.1) to change them to foreground.
     mask = torch.where(is_hole, 0.1, mask)
     return mask


 def concat_points(old_point_inputs, new_points, new_labels):
     """Add new points and labels to previous point inputs (add at the end)."""
     if old_point_inputs is None:
         points, labels = new_points, new_labels
     else:
         points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
         labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)

     return {"point_coords": points, "point_labels": labels}
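The only functional change in this commit is at the end of `load_video_frames`: the unconditional `images.cuda()` / `img_mean.cuda()` / `img_std.cuda()` calls are replaced by a `torch.cuda.is_available()` check, so requesting GPU residency (`offload_video_to_cpu=False`) now falls back to CPU instead of raising on machines without CUDA. A minimal usage sketch follows, assuming the `sam2` package is importable and a hypothetical directory `./video_frames` holds JPEG frames named `00000.jpg`, `00001.jpg`, ...:

# Usage sketch (hypothetical paths, not part of the commit).
from sam2.utils.misc import load_video_frames

images, video_height, video_width = load_video_frames(
    video_path="./video_frames",
    image_size=1024,  # SAM2 models typically take 1024x1024 input
    offload_video_to_cpu=False,  # previously crashed without CUDA; now degrades to CPU
)
print(images.shape, images.device)  # (num_frames, 3, 1024, 1024) on cuda:0 or cpu

Note that the asynchronous path is untouched by this commit: `AsyncVideoFrameLoader.__getitem__` still calls `img.cuda(non_blocking=True)` unconditionally, so `async_loading_frames=True` combined with `offload_video_to_cpu=False` would still fail on a CPU-only machine.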