Update app.py
app.py CHANGED
@@ -167,51 +167,6 @@ def show_box(box, ax):
     w, h = box[2] - box[0], box[3] - box[1]
     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
 
-def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
-    combined_images = [] # List to store filenames of images with masks overlaid
-    mask_images = [] # List to store filenames of separate mask images
-
-    for i, (mask, score) in enumerate(zip(masks, scores)):
-        # ---- Original Image with Mask Overlaid ----
-        plt.figure(figsize=(10, 10))
-        plt.imshow(image)
-        show_mask(mask, plt.gca(), borders=borders) # Draw the mask with borders
-        """
-        if point_coords is not None:
-            assert input_labels is not None
-            show_points(point_coords, input_labels, plt.gca())
-        """
-        if box_coords is not None:
-            show_box(box_coords, plt.gca())
-        if len(scores) > 1:
-            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
-        plt.axis('off')
-
-        # Save the figure as a JPG file
-        combined_filename = f"combined_image_{i+1}.jpg"
-        plt.savefig(combined_filename, format='jpg', bbox_inches='tight')
-        combined_images.append(combined_filename)
-
-        plt.close() # Close the figure to free up memory
-
-        # ---- Separate Mask Image (White Mask on Black Background) ----
-        # Create a black image
-        mask_image = np.zeros_like(image, dtype=np.uint8)
-
-        # The mask is a binary array where the masked area is 1, else 0.
-        # Convert the mask to a white color in the mask_image
-        mask_layer = (mask > 0).astype(np.uint8) * 255
-        for c in range(3): # Assuming RGB, repeat mask for all channels
-            mask_image[:, :, c] = mask_layer
-
-        # Save the mask image
-        mask_filename = f"mask_image_{i+1}.png"
-        Image.fromarray(mask_image).save(mask_filename)
-        mask_images.append(mask_filename)
-
-        plt.close() # Close the figure to free up memory
-
-    return combined_images, mask_images
 
 def load_model(checkpoint):
     # Load model accordingly to user's choice
@@ -254,9 +209,11 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
     # segment and track one object
     # predictor.reset_state(inference_state) # if any previous tracking, reset
 
+    new_working_frame = None
     # Add new point
     if working_frame == None:
         ann_frame_idx = 0 # the frame index we interact with
+        new_working_frame = "frames_output_images/frame_0.jpg"
     else:
         # Use a regular expression to find the integer
         match = re.search(r'frame_(\d+)', working_frame)
@@ -264,6 +221,7 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
         # Extract the integer from the match
         frame_number = int(match.group(1))
         ann_frame_idx = frame_number
+        new_working_frame = f"frames_output_images/frame_{ann_frame_idx}.jpg"
 
     ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
 
@@ -292,7 +250,7 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
     plt.close()
     torch.cuda.empty_cache()
 
-    return "output_first_frame.jpg", frame_names, inference_state
+    return "output_first_frame.jpg", frame_names, inference_state, new_working_frame
 
 def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, progress=gr.Progress(track_tqdm=True)):
     #### PROPAGATION ####
@@ -346,7 +304,7 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
     print(f"JPEG_IMAGES: {jpeg_images}")
 
     if vis_frame_type == "check":
-        return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=jpeg_images, value=None, visible=
+        return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=jpeg_images, value=None, visible=True)
     elif vis_frame_type == "render":
         # Create a video clip from the image sequence
         original_fps = get_video_fps(video_in)
@@ -378,7 +336,7 @@ def switch_working_frame(working_frame, scanned_frames, video_frames_dir):
         frame_number = int(match.group(1))
         ann_frame_idx = frame_number
         new_working_frame = os.path.join(video_frames_dir, scanned_frames[ann_frame_idx])
-    return new_working_frame, gr.State([]), gr.State([]), new_working_frame, new_working_frame
+    return new_working_frame, gr.State([]), gr.State([]), new_working_frame, new_working_frame, new_working_frame
 
 with gr.Blocks() as demo:
     first_frame_path = gr.State()
@@ -453,19 +411,19 @@ with gr.Blocks() as demo:
             queue = False
         )
 
-
+
         working_frame.change(
            fn = switch_working_frame,
            inputs = [working_frame, scanned_frames, video_frames_dir],
-           outputs = [first_frame_path, tracking_points, trackings_input_label, input_first_frame_image, points_map],
+           outputs = [first_frame_path, tracking_points, trackings_input_label, input_first_frame_image, points_map, working_frame],
            queue=False
        )
-
+
 
        submit_btn.click(
           fn = sam_process,
          inputs = [input_first_frame_image, checkpoint, tracking_points, trackings_input_label, video_frames_dir, scanned_frames, working_frame],
-         outputs = [output_result, stored_frame_names, stored_inference_state]
+         outputs = [output_result, stored_frame_names, stored_inference_state, working_frame]
        )
 
        propagate_btn.click(
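
Note on the sam_process change: the handler now derives new_working_frame from the same "frame_<n>" naming it already uses for ann_frame_idx, falling back to frame 0 when no working frame is set. A minimal, runnable sketch of that logic in isolation (the helper name frame_path_from_name is hypothetical; the frames_output_images/ path is taken from the diff):

import re

def frame_path_from_name(working_frame):
    """Map a working-frame name like 'frame_12.jpg' to its index and output path."""
    if working_frame is None:
        ann_frame_idx = 0
    else:
        match = re.search(r'frame_(\d+)', working_frame)
        if match is None:
            raise ValueError(f"no frame index found in {working_frame!r}")
        ann_frame_idx = int(match.group(1))
    return ann_frame_idx, f"frames_output_images/frame_{ann_frame_idx}.jpg"

print(frame_path_from_name("frame_12.jpg"))  # (12, 'frames_output_images/frame_12.jpg')
print(frame_path_from_name(None))            # (0, 'frames_output_images/frame_0.jpg')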
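
Note on the propagate_to_all change: the "check" branch now returns gr.update(choices=jpeg_images, value=None, visible=True), the standard Gradio way to repopulate and reveal a selector from a callback. A minimal sketch of that pattern, with hypothetical component names rather than the app's actual layout:

import gradio as gr

def scan_frames():
    jpeg_images = ["frame_0.jpg", "frame_1.jpg"]  # placeholder values
    # A single gr.update(...) can change choices, value, and visibility at once.
    return gr.update(choices=jpeg_images, value=None, visible=True)

with gr.Blocks() as demo:
    check_btn = gr.Button("Check frames")
    frame_picker = gr.Dropdown(choices=[], visible=False, label="Pick a frame")
    check_btn.click(fn=scan_frames, inputs=None, outputs=frame_picker)

demo.launch()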
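
Note on the UI wiring changes: sam_process and switch_working_frame each return one extra value, and working_frame is appended to the corresponding outputs lists, so the currently selected frame stays in sync after every run. A minimal, self-contained sketch of that Gradio pattern (the process function and the components below are simplified stand-ins, not the app's real code):

import gradio as gr

def process(working_frame):
    # ... segmentation work would happen here ...
    result = f"processed {working_frame}"
    # Hypothetical path, mirroring the frames_output_images/ layout in the diff
    new_working_frame = "frames_output_images/frame_0.jpg"
    # One return value per component in outputs=[...]; the last one
    # refreshes the working_frame component itself.
    return result, new_working_frame

with gr.Blocks() as demo:
    working_frame = gr.Textbox(value="frame_0.jpg", label="working frame")
    output_result = gr.Textbox(label="result")
    submit_btn = gr.Button("Submit")
    submit_btn.click(fn=process, inputs=[working_frame],
                     outputs=[output_result, working_frame])

demo.launch()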