jiuface committed on
Commit
2b27106
·
1 Parent(s): cfb74b6

add merge multi mask

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -30,7 +30,7 @@ SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
30
  @spaces.GPU(duration=20)
31
  @torch.inference_mode()
32
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
33
- def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0) -> Optional[Image.Image]:
34
  if not image_input:
35
  gr.Info("Please upload an image.")
36
  return None
@@ -72,6 +72,14 @@ def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=
72
  if dilate > 0:
73
  mask = cv2.dilate(mask, kernel, iterations=1)
74
  images.append(mask)
 
 
 
 
 
 
 
 
75
  return images
76
 
77
 
@@ -84,6 +92,8 @@ with gr.Blocks() as demo:
84
  ['<OD>', '<CAPTION_TO_PHRASE_GROUNDING>', '<DENSE_REGION_CAPTION>', '<REGION_PROPOSAL>', '<OCR_WITH_REGION>', '<REFERRING_EXPRESSION_SEGMENTATION>', '<REGION_TO_SEGMENTATION>', '<OPEN_VOCABULARY_DETECTION>', '<REGION_TO_CATEGORY>', '<REGION_TO_DESCRIPTION>'], value="<CAPTION_TO_PHRASE_GROUNDING>", label="Task Prompt", info="task prompts"
85
  )
86
  dilate = gr.Slider(label="dilate mask", minimum=0, maximum=50, value=10, step=1)
 
 
87
  text_prompt = gr.Textbox(label='Text prompt', placeholder='Enter text prompts')
88
  submit_button = gr.Button(value='Submit', variant='primary')
89
  with gr.Column():
@@ -91,7 +101,7 @@ with gr.Blocks() as demo:
91
  print(image, image_url, task_prompt, text_prompt, image_gallery)
92
  submit_button.click(
93
  fn = process_image,
94
- inputs = [image, image_url, task_prompt, text_prompt, dilate],
95
  outputs = [image_gallery,],
96
  show_api=False
97
  )
 
30
  @spaces.GPU(duration=20)
31
  @torch.inference_mode()
32
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
33
+ def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0, merge_masks=False) -> Optional[Image.Image]:
34
  if not image_input:
35
  gr.Info("Please upload an image.")
36
  return None
 
72
  if dilate > 0:
73
  mask = cv2.dilate(mask, kernel, iterations=1)
74
  images.append(mask)
75
+
76
+ if merge_masks:
77
+ final_images = []
78
+ merged_mask = np.zeros_like(images[0], dtype=np.uint8)
79
+ for mask in images:
80
+ merged_mask = cv2.bitwise_or(merged_mask, mask)
81
+ final_images = [merged_mask]
82
+ return final_images
83
  return images
84
 
85
 
 
92
  ['<OD>', '<CAPTION_TO_PHRASE_GROUNDING>', '<DENSE_REGION_CAPTION>', '<REGION_PROPOSAL>', '<OCR_WITH_REGION>', '<REFERRING_EXPRESSION_SEGMENTATION>', '<REGION_TO_SEGMENTATION>', '<OPEN_VOCABULARY_DETECTION>', '<REGION_TO_CATEGORY>', '<REGION_TO_DESCRIPTION>'], value="<CAPTION_TO_PHRASE_GROUNDING>", label="Task Prompt", info="task prompts"
93
  )
94
  dilate = gr.Slider(label="dilate mask", minimum=0, maximum=50, value=10, step=1)
95
+ merge_masks = gr.Checkbox(label="Merge masks", value=False)
96
+
97
  text_prompt = gr.Textbox(label='Text prompt', placeholder='Enter text prompts')
98
  submit_button = gr.Button(value='Submit', variant='primary')
99
  with gr.Column():
 
101
  print(image, image_url, task_prompt, text_prompt, image_gallery)
102
  submit_button.click(
103
  fn = process_image,
104
+ inputs = [image, image_url, task_prompt, text_prompt, dilate, merge_masks],
105
  outputs = [image_gallery,],
106
  show_api=False
107
  )