File size: 1,374 Bytes
052f125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os
import torch
import argparse
import numpy as np
from PIL import Image
from ultralytics.models.sam import SAM2VideoPredictor


def main(args):
    """Run SAM2 on one video with a single point prompt and save the masks.

    The video is read from ``input_animatediff/<video_name>.mp4`` and one PNG
    per returned result is written to ``masks_animatediff/<video_name>/``,
    named by its zero-padded index (``000.png``, ``001.png``, ...).

    Args:
        args: Parsed command-line namespace. Must provide ``video_name``
            (video file name without extension) and ``x`` / ``y`` (the
            coordinates of the prompt point).
    """
    # Create SAM2VideoPredictor
    overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt")
    predictor = SAM2VideoPredictor(overrides=overrides)

    video_name = args.video_name
    # labels=[1] accompanies the single prompt point; per the Ultralytics SAM
    # prompt convention this presumably marks it as foreground -- confirm.
    results = predictor(
        source=f"input_animatediff/{video_name}.mp4",
        points=[args.x, args.y],
        labels=[1],
    )

    # The output directory is the same for every frame, so create it once
    # before the loop. exist_ok=True replaces the check-then-create pattern
    # and is safe if the directory already exists. Note the directory is now
    # created even when the predictor returns no results.
    mask_dir = f'masks_animatediff/{video_name}'
    os.makedirs(mask_dir, exist_ok=True)

    for frame_idx, result in enumerate(results):
        # NOTE(review): result.masks is presumably None when nothing was
        # segmented in a frame, which would raise AttributeError here --
        # confirm whether such frames need explicit handling.
        mask = result.masks.data.squeeze().to(torch.float16)
        # Scale mask values by 255 and convert to an 8-bit array for PIL.
        mask = (mask * 255).cpu().numpy().astype(np.uint8)
        Image.fromarray(mask).save(f'{mask_dir}/{frame_idx:03d}.png')

if __name__ == "__main__":
    # Command-line entry point: collect the video name and the prompt point,
    # then hand the parsed namespace to main().
    parser = argparse.ArgumentParser(
        description="Process a video and generate masks using SAM2VideoPredictor."
    )

    # Required: selects which video to process.
    parser.add_argument(
        "--video_name",
        type=str,
        required=True,
        help="Name of the video file (without extension).",
    )

    # Optional prompt-point coordinates; both default to 255.
    parser.add_argument(
        "--x",
        type=int,
        default=255,
        help="X coordinate of the point.",
    )
    parser.add_argument(
        "--y",
        type=int,
        default=255,
        help="Y coordinate of the point.",
    )

    args = parser.parse_args()
    main(args)