import os import torch import argparse import numpy as np from PIL import Image from ultralytics.models.sam import SAM2VideoPredictor def main(args): # Create SAM2VideoPredictor overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt") predictor = SAM2VideoPredictor(overrides=overrides) video_name = args.video_name results = predictor(source=f"input_animatediff/{video_name}.mp4",points=[args.x, args.y],labels=[1]) for i in range(len(results)): mask = (results[i].masks.data).squeeze().to(torch.float16) mask = (mask * 255).cpu().numpy().astype(np.uint8) mask_image = Image.fromarray(mask) mask_dir = f'masks_animatediff/{video_name}' if not os.path.exists(mask_dir): os.makedirs(mask_dir) mask_image.save(mask_dir + f'/{str(i).zfill(3)}.png') if __name__ == "__main__": parser = argparse.ArgumentParser(description="Process a video and generate masks using SAM2VideoPredictor.") parser.add_argument("--video_name", type=str, required=True, help="Name of the video file (without extension).") parser.add_argument("--x", type=int, default=255, help="X coordinate of the point.") parser.add_argument("--y", type=int, default=255, help="Y coordinate of the point.") args = parser.parse_args() main(args)