# Spaces: Paused
import os | |
import cv2 | |
import torch | |
import numpy as np | |
from tqdm import tqdm | |
from torchvision import transforms | |
import imageio | |
import argparse | |
import sys | |
sys.path.append("RAFT/core") | |
from raft import RAFT | |
from utils.utils import InputPadder | |
# Compute device, chosen once when this module is imported: "cuda" if torch
# reports CUDA as available at that moment, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def load_raft_model(ckpt_path):
    """Build a RAFT network, load its weights and return it ready for inference.

    Args:
        ckpt_path: Path to the checkpoint file handed to ``torch.load``.

    Returns:
        The bare RAFT module (unwrapped from ``DataParallel``), moved to
        ``DEVICE`` and switched to eval mode.
    """
    raft_config = argparse.Namespace(
        **{
            "small": False,
            "mixed_precision": False,
            "alternate_corr": False,
            "dropout": 0.0,
            "max_depth": 8,
            "depth_network": False,
            "depth_residual": False,
            "depth_scale": 1.0,
        }
    )
    # The state dict is loaded into the DataParallel wrapper, not the bare
    # network -- presumably because the checkpoint was saved from a wrapped
    # model and its keys carry the wrapper's prefix (verify against the file).
    wrapped = torch.nn.DataParallel(RAFT(raft_config))
    state = torch.load(ckpt_path, map_location=DEVICE)
    wrapped.load_state_dict(state)

    net = wrapped.module
    net = net.to(DEVICE)
    return net.eval()
def _inside_frame(pos, width, height):
    """Return an (H, W) bool map of tracked points whose (x, y) lies inside the frame."""
    xs, ys = pos[..., 0], pos[..., 1]
    return (0 <= xs) & (xs < width) & (0 <= ys) & (ys < height)


def _splat_mask(pos, vis, width, height):
    """Mark the rounded positions of the still-visible points in a dilated uint8 mask."""
    mask = np.zeros((height, width), np.uint8)
    rows, cols = np.where(vis)
    px = np.round(pos[rows, cols, 0]).astype(int)
    py = np.round(pos[rows, cols, 1]).astype(int)
    # Rounding can push a point that passed the float bounds test onto
    # index == width/height, so the integer coordinates are re-checked.
    inb = (0 <= px) & (px < width) & (0 <= py) & (py < height)
    mask[py[inb], px[inb]] = 1
    return cv2.dilate(mask, np.ones((2, 2), np.uint8))


def run_masking(video_path, output_path, mask_path, raft):
    """Track first-frame pixels through a video and black out everything else.

    Every frame is resized to 720x480. Each first-frame pixel carries a
    position that is advanced by the RAFT flow between consecutive frames; a
    pixel whose position leaves the frame is dropped. The per-frame mask is the
    set of rounded positions of the surviving pixels, dilated with a 2x2
    kernel. When that mask covers less than 30% of the frame, positions are
    re-anchored with the flow computed directly from the first frame to the
    current one. If the re-anchored mask still covers less than one sixth of
    the frame, it is frozen and reused for all remaining frames.

    Args:
        video_path: Input video readable by ``cv2.VideoCapture``.
        output_path: Destination of the masked video (first frame is written
            unmasked; later frames have pixels outside the mask set to 0).
        mask_path: Destination ``.npz`` file; receives a boolean array of
            shape (frames, H, W) under the key ``"mask"``.
        raft: Flow model called as ``raft(img1, img2, iters=20, test_mode=True)``.

    Returns:
        None. On an unopenable video or unreadable first frame a message is
        printed and nothing is written.
    """
    cap = cv2.VideoCapture(video_path)
    writer = None
    try:
        if not cap.isOpened():
            print(f"Failed to open video: {video_path}")
            return

        fps = cap.get(cv2.CAP_PROP_FPS)
        n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        ok, first = cap.read()
        if not ok:
            print(f"Failed to read first frame in {video_path}")
            return

        resize_to = (720, 480)
        first = cv2.resize(first, resize_to)
        H, W, _ = first.shape
        area_thresh = (H * W) // 6

        # grid[y, x] == (x, y): the starting position of every first-frame pixel.
        grid = np.stack(np.meshgrid(np.arange(W), np.arange(H)), -1).astype(np.float32)
        pos = grid.copy()
        vis = np.ones((H, W), dtype=bool)

        # The frame rate is truncated to an integer before reaching the writer.
        writer = imageio.get_writer(output_path, fps=int(fps))

        prev = first.copy()
        freeze_mask = False
        frozen_mask = None

        # Frames are BGR from OpenCV; the channel axis is reversed for the writer.
        writer.append_data(first[:, :, ::-1])
        all_masks = [np.ones((H, W), dtype=bool)]

        # NOTE(review): ToTensor rescales uint8 input to [0, 1] and the channel
        # order stays BGR -- confirm this is what this RAFT build expects.
        as_tensor = transforms.ToTensor()

        def to_tensor(bgr):
            return as_tensor(bgr).unsqueeze(0).to(DEVICE)

        def raft_flow(img1_bgr, img2_bgr):
            t1, t2 = to_tensor(img1_bgr), to_tensor(img2_bgr)
            padder = InputPadder(t1.shape)
            i1, i2 = padder.pad(t1, t2)
            with torch.no_grad():
                _, flow = raft(i1, i2, iters=20, test_mode=True)
            return padder.unpad(flow)[0].permute(1, 2, 0).cpu().numpy()

        # The loop is bounded by the reported frame count; a failed read ends
        # it earlier if the stream turns out to be shorter.
        for _ in range(1, n_frames):
            ok, cur = cap.read()
            if not ok:
                break
            cur = cv2.resize(cur, resize_to)

            if freeze_mask:
                m = frozen_mask
            else:
                # The flow field is added element-wise to the position array.
                pos += raft_flow(prev, cur)
                vis &= _inside_frame(pos, W, H)
                m = _splat_mask(pos, vis, W, H)

                if m.sum() / (H * W) < 0.3:
                    # Too little survives: re-anchor every pixel using the flow
                    # computed directly from the first frame.
                    pos = grid + raft_flow(first, cur)
                    vis = _inside_frame(pos, W, H)
                    m = _splat_mask(pos, vis, W, H)
                    if m.sum() < area_thresh:
                        freeze_mask = True
                        frozen_mask = m.copy()

            effective_mask = m.astype(bool)
            all_masks.append(effective_mask)

            out = cur.copy()
            out[~effective_mask] = 0
            writer.append_data(out[:, :, ::-1])

            # Once frozen, no further flow is computed, so prev stays as it is.
            if not freeze_mask:
                prev = cur
    finally:
        # Release the handles on every exit path, including early returns and
        # exceptions raised while processing.
        if writer is not None:
            writer.close()
        cap.release()

    np.savez_compressed(mask_path, mask=np.stack(all_masks, axis=0))
def is_valid_mask_file(mask_path):
    """Return True when ``mask_path`` is a readable archive holding a ``"mask"`` array.

    The array is fully read as the validity check. Any ordinary failure
    (missing file, damaged archive, absent key) yields False;
    KeyboardInterrupt and SystemExit are deliberately not caught.
    """
    try:
        with np.load(mask_path) as archive:
            archive["mask"]
    except Exception:
        return False
    return True


if __name__ == "__main__":
    # Batch driver: processes the [start_idx, end_idx) slice of the sorted
    # .mp4 files found in --video_path.
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--mask_path", type=str, required=True)
    parser.add_argument("--raft_ckpt", type=str, required=True)
    parser.add_argument("--start_idx", type=int, required=True)
    parser.add_argument("--end_idx", type=int, required=True)
    parser.add_argument("--gpu_id", type=int, required=True)
    args = parser.parse_args()

    # NOTE(review): torch.cuda.is_available() has already been called when
    # DEVICE was first defined at import time; if that call initialised CUDA
    # device enumeration, setting this variable afterwards may not restrict
    # the visible GPUs -- confirm the selection takes effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    os.makedirs(args.output_path, exist_ok=True)
    os.makedirs(args.mask_path, exist_ok=True)

    video_list = sorted(
        f for f in os.listdir(args.video_path)
        if f.endswith(".mp4")
    )
    selected_videos = video_list[args.start_idx : args.end_idx]
    print(f"[GPU {args.gpu_id}] Processing {len(selected_videos)} videos: {args.start_idx} to {args.end_idx}")

    model = load_raft_model(args.raft_ckpt)

    for fname in tqdm(selected_videos, desc="Batch Processing"):
        input_path = os.path.join(args.video_path, fname)
        # Swap only the final extension; str.replace would also rewrite any
        # earlier ".mp4" inside the file name.
        mask_path = os.path.join(args.mask_path, os.path.splitext(fname)[0] + ".npz")
        output_path = os.path.join(args.output_path, fname)

        if os.path.exists(mask_path):
            if is_valid_mask_file(mask_path):
                continue
            # A damaged mask is regenerated even if an output video is already
            # present, so the message below is accurate.
            print(f"⚠️ Mask corrupt or unreadable: {mask_path} - Regenerating")
        elif os.path.exists(output_path):
            # No mask yet but an output video exists: skipped, as before.
            continue

        run_masking(input_path, output_path, mask_path, model)