fffiloni commited on
Commit
a4918b7
·
verified ·
1 Parent(s): 66b8369

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -5
app.py CHANGED
@@ -117,13 +117,23 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_l
117
 
118
  return combined_images, mask_images
119
 
120
- def sam_process(input_image, tracking_points, trackings_input_label):
121
  image = Image.open(input_image)
122
  image = np.array(image.convert("RGB"))
123
 
124
- sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
125
- model_cfg = "sam2_hiera_t.yaml"
126
-
 
 
 
 
 
 
 
 
 
 
127
  sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
128
 
129
  predictor = SAM2ImagePredictor(sam2_model)
@@ -164,7 +174,9 @@ with gr.Blocks() as demo:
164
  with gr.Row():
165
  point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include")
166
  clear_points_btn = gr.Button("Clear Points")
 
167
  with gr.Column():
 
168
  points_map = gr.Image(
169
  label="points map",
170
  type="filepath",
@@ -199,7 +211,7 @@ with gr.Blocks() as demo:
199
 
200
  submit_btn.click(
201
  fn = sam_process,
202
- inputs = [input_image, tracking_points, trackings_input_label],
203
  outputs = [output_result, output_result_mask]
204
  )
205
  demo.launch()
 
117
 
118
  return combined_images, mask_images
119
 
120
+ def sam_process(input_image, checkpoint, tracking_points, trackings_input_label):
121
  image = Image.open(input_image)
122
  image = np.array(image.convert("RGB"))
123
 
124
+ if checkpoint == "tiny":
125
+ sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
126
+ model_cfg = "sam2_hiera_t.yaml"
127
+ elif checkpoint == "samll":
128
+ sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
129
+ model_cfg = "sam2_hiera_s.yaml"
130
+ elif checkpoint == "base-plus":
131
+ sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
132
+ model_cfg = "sam2_hiera_b+.yaml"
133
+ elif checkpoint == "large":
134
+ sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
135
+ model_cfg = "sam2_hiera_l.yaml"
136
+
137
  sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
138
 
139
  predictor = SAM2ImagePredictor(sam2_model)
 
174
  with gr.Row():
175
  point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include")
176
  clear_points_btn = gr.Button("Clear Points")
177
+
178
  with gr.Column():
179
+ checkpoint = gr.Dropbox(label="Checkpoint", choices=["tiny", "small", "base-plus", "large"], value="tiny")
180
  points_map = gr.Image(
181
  label="points map",
182
  type="filepath",
 
211
 
212
  submit_btn.click(
213
  fn = sam_process,
214
+ inputs = [input_image, checkpoint, tracking_points, trackings_input_label],
215
  outputs = [output_result, output_result_mask]
216
  )
217
  demo.launch()