wondervictor commited on
Commit
77ae695
·
verified ·
1 Parent(s): 73ce42e

Update model/segment_anything_2/sam2/modeling/sam2_base.py

Browse files
model/segment_anything_2/sam2/modeling/sam2_base.py CHANGED
@@ -312,6 +312,11 @@ class SAM2Base(torch.nn.Module):
312
  sam_point_coords = torch.zeros(B, 1, 2, device=device)
313
  sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
314
 
 
 
 
 
 
315
  # b) Handle mask prompts
316
  if mask_inputs is not None:
317
  # If mask_inputs is provided, downsize it into low-res mask input if needed
@@ -333,7 +338,7 @@ class SAM2Base(torch.nn.Module):
333
  sam_mask_prompt = None
334
 
335
  sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
336
- points=(sam_point_coords, sam_point_labels),
337
  boxes=None,
338
  masks=sam_mask_prompt,
339
  text_embeds=text_inputs
 
312
  sam_point_coords = torch.zeros(B, 1, 2, device=device)
313
  sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
314
 
315
+ sam_point_prompt = (sam_point_coords, sam_point_labels)
316
+ # added by YxZhang to forbid contemporary using text prompt and point prompt
317
+ if text_inputs is not None:
318
+ sam_point_prompt = None
319
+
320
  # b) Handle mask prompts
321
  if mask_inputs is not None:
322
  # If mask_inputs is provided, downsize it into low-res mask input if needed
 
338
  sam_mask_prompt = None
339
 
340
  sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
341
+ points=sam_point_prompt,
342
  boxes=None,
343
  masks=sam_mask_prompt,
344
  text_embeds=text_inputs