File size: 5,418 Bytes
2df809d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#!/usr/bin/env python3
"""
Preprocess scenes by sorting images and generating image/video collections.

This script processes scenes in parallel using a thread pool, updating metadata
with sorted images, trajectories, intrinsics, and generating pair, image collection,
and video collection data. The processed metadata is saved to a new file in each scene directory.

Usage:
    python generate_set_arkitscenes.py --root /path/to/data --splits Training Test --max_interval 5.0 --num_workers 8
"""

import os
import os.path as osp
import argparse
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm


def get_timestamp(img_name):
    """
    Extract the timestamp from an image filename.

    The timestamp is the last underscore-separated token of the basename,
    e.g. "41069021_6845.152.png" -> 6845.152.

    Args:
        img_name (str): The image filename.

    Returns:
        float: The extracted timestamp.

    Raises:
        ValueError: If the last underscore-separated token is not a number.
    """
    # Use splitext instead of a hard-coded [:-4] slice so extensions of any
    # length (.jpeg, .webp, ...) are handled, not only 3-character ones.
    stem, _ext = osp.splitext(img_name)
    return float(stem.split("_")[-1])


def process_scene(root, split, scene, max_interval):
    """
    Process a single scene by sorting its images by timestamp, updating trajectories,
    intrinsics, and pairings, and generating image/video collections.

    Reads ``<root>/<split>/<scene>/scene_metadata.npz`` and writes the result to
    ``new_scene_metadata.npz`` in the same directory.

    Args:
        root (str): Root directory of the dataset.
        split (str): The dataset split (e.g., 'Training', 'Test').
        scene (str): The scene identifier.
        max_interval (float): Maximum allowed time interval (in seconds) between images
            to consider them in the same video collection.
    """
    scene_dir = osp.join(root, split, scene)
    metadata_path = osp.join(scene_dir, "scene_metadata.npz")

    # Load the scene metadata
    with np.load(metadata_path) as data:
        images = data["images"]
        trajectories = data["trajectories"]
        intrinsics = data["intrinsics"]
        pairs = data["pairs"]

    # Sort images by their numeric timestamp. Sorting by the raw filename string
    # (the previous behavior) is only correct when timestamps are zero-padded to
    # equal width; parsing the timestamp makes the temporal order explicit.
    imgs_with_indices = sorted(enumerate(images), key=lambda x: get_timestamp(x[1]))
    indices, images = zip(*imgs_with_indices)
    indices = np.array(indices)
    # Map each original index to its position in the sorted order.
    index2sorted = {index: i for i, index in enumerate(indices)}

    # Reorder trajectories and intrinsics based on the new image order
    trajectories = trajectories[indices]
    intrinsics = intrinsics[indices]

    # Update pair indices (each pair is (id1, id2, score))
    pairs = [(index2sorted[id1], index2sorted[id2], score) for id1, id2, score in pairs]

    # Form image_collection: mapping from an image id to a list of (other image id, score)
    image_collection = {}
    for id1, id2, score in pairs:
        image_collection.setdefault(id1, []).append((id2, score))

    # Form video_collection: for each image i, collect the indices of subsequent
    # images whose timestamp is within max_interval of image i.
    video_collection = {}
    for i, image in enumerate(images):
        # Sentinel: if no later image exceeds the interval, the window extends to
        # the end of the sequence. (The previous for-loop reused the loop variable
        # after normal completion, which silently dropped the last image.)
        end = len(images)
        for j in range(i + 1, len(images)):
            if get_timestamp(images[j]) - get_timestamp(image) > max_interval:
                end = j  # first image outside the window; excluded by range below
                break
        video_collection[i] = list(range(i + 1, end))

    # Save the new metadata. NOTE(review): dicts/lists are stored by np.savez as
    # object arrays (pickled); loading them back requires allow_pickle=True.
    output_path = osp.join(scene_dir, "new_scene_metadata.npz")
    np.savez(
        output_path,
        images=images,
        trajectories=trajectories,
        intrinsics=intrinsics,
        pairs=pairs,
        image_collection=image_collection,
        video_collection=video_collection,
    )
    print(f"Processed scene: {scene}")


def main(args):
    """
    Process every scene of the requested splits in parallel.

    For each split, reads the scene list from ``<root>/<split>/all_metadata.npz``,
    submits one ``process_scene`` task per scene to a thread pool, and waits for
    all tasks while showing a progress bar. Worker exceptions are re-raised here.

    Args:
        args (argparse.Namespace): Parsed CLI options (root, splits,
            max_interval, num_workers).
    """
    # Thread pool for parallel scene processing; the context manager joins
    # all workers before returning.
    with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
        futures = []
        for split in args.splits:
            meta_file = osp.join(args.root, split, "all_metadata.npz")
            with np.load(meta_file) as meta:
                split_scenes = meta["scenes"]

            # Fan out one task per scene of this split.
            futures.extend(
                executor.submit(
                    process_scene, args.root, split, scene_id, args.max_interval
                )
                for scene_id in split_scenes
            )

        # Drain completed futures with a progress bar; .result() re-raises
        # any exception raised inside process_scene.
        for done in tqdm(
            as_completed(futures), total=len(futures), desc="Processing scenes"
        ):
            done.result()


if __name__ == "__main__":
    # Command-line interface: all options have defaults so the script can run
    # with no arguments (empty root resolves relative to the working directory).
    cli = argparse.ArgumentParser(
        description="Preprocess scene data to update metadata with sorted images and collections."
    )
    cli.add_argument(
        "--root",
        type=str,
        default="",
        help="Root directory containing the dataset splits.",
    )
    cli.add_argument(
        "--splits",
        type=str,
        nargs="+",
        default=["Training", "Test"],
        help="List of dataset splits to process (e.g., Training Test).",
    )
    cli.add_argument(
        "--max_interval",
        type=float,
        default=5.0,
        help="Maximum time interval (in seconds) between images to consider them in the same video sequence.",
    )
    cli.add_argument(
        "--num_workers",
        type=int,
        default=8,
        help="Number of worker threads for parallel processing.",
    )
    main(cli.parse_args())