File size: 4,778 Bytes
17e77ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import json
import random
from collections import defaultdict, deque

_STEP_FUNCTIONS = ("Relative", "Azimuth", "Between")
_DIRECTIONS = (
    "north", "south", "east", "west",
    "northeast", "northwest", "southeast", "southwest",
)


def _all_steps_feed_last(steps):
    """Return True if every step is the last step or one of its ancestors.

    Integer entries in a step's ``inputs`` are references to earlier step
    ids; string entries (location names, directions, distances) are ignored.
    The walk starts at the last step and follows references backwards.
    """
    parents = defaultdict(list)
    all_ids = set()
    for step in steps:
        all_ids.add(step["id"])
        for inp in step["inputs"]:
            if isinstance(inp, int):  # reference to a previous step
                parents[step["id"]].append(inp)

    last_id = steps[-1]["id"]
    # Mark nodes when they are enqueued so each one is queued at most once.
    visited = {last_id}
    queue = deque([last_id])
    while queue:
        curr = queue.popleft()
        for parent in parents[curr]:
            if parent not in visited:
                visited.add(parent)
                queue.append(parent)
    return all_ids <= visited


def _draw_candidate(rng):
    """Draw one random candidate program using ``rng``.

    Returns ``(locations, steps, used_locations)`` where ``locations`` are the
    available ``LOC_k`` names, ``steps`` is the list of step dicts that were
    actually produced (possibly empty), and ``used_locations`` is the set of
    location names referenced by any step.
    """
    num_locations = rng.randint(1, 3)
    locations = [f"LOC_{i + 1}" for i in range(num_locations)]
    used_locations = set()
    # Step inputs may reference a location name or the id of an earlier step.
    all_refs = list(locations)
    steps = []

    num_steps = rng.randint(2, 5)
    for _ in range(num_steps):
        func = rng.choice(_STEP_FUNCTIONS)
        if func == "Between":
            if len(all_refs) < 2:
                # Not enough operands yet; this draw yields no step.
                continue
            bases = rng.sample(all_refs, 2)
            inputs = list(bases)
        else:
            base = rng.choice(all_refs)
            bases = [base]
            if func == "Relative":
                direction = rng.choice(_DIRECTIONS)
                inputs = [base, direction, f"{rng.randint(1, 10)} km"]
            else:  # Azimuth
                angle = f"{rng.randint(0, 359)}°"
                inputs = [base, angle, f"{rng.randint(1, 10)} km"]

        used_locations.update(b for b in bases if isinstance(b, str))
        step_id = len(steps) + 1
        steps.append({"id": step_id, "function": func, "inputs": inputs})
        all_refs.append(step_id)

    return locations, steps, used_locations


def generate_localization_samples(n, *, rng=None):
    """Generate ``n`` random localization samples in which every step matters.

    Each sample is a dict with keys ``"index"`` (1-based position in the
    result), ``"instruction"`` (always ``""``) and ``"steps"``.  Each step is
    a dict ``{"id", "function", "inputs"}`` where ``function`` is one of:

    * ``"Relative"`` -- inputs ``[ref, direction, "<1-10> km"]``
    * ``"Azimuth"``  -- inputs ``[ref, "<0-359>°", "<1-10> km"]``
    * ``"Between"``  -- inputs ``[ref, ref]``

    and each ``ref`` is a location name ``"LOC_k"`` or an earlier step id.

    Candidates are redrawn until they use every available location and every
    step is an ancestor of the last step.

    Args:
        n: Number of samples to produce; ``n <= 0`` yields ``[]``.
        rng: Optional source of randomness exposing ``randint``, ``choice``
            and ``sample`` (e.g. ``random.Random(seed)``) for reproducible
            output.  Defaults to the module-level ``random`` functions.

    Returns:
        A list of ``n`` sample dicts.
    """
    rng = random if rng is None else rng
    all_data = []
    while len(all_data) < n:
        locations, steps, used_locations = _draw_candidate(rng)
        if not steps:
            continue  # every draw was skipped; regenerate
        if used_locations.issuperset(locations) and _all_steps_feed_last(steps):
            all_data.append({
                "index": len(all_data) + 1,
                "instruction": "",
                "steps": steps,
            })
        # otherwise the candidate is discarded and a new one is drawn
    return all_data


def write_custom_json(data, filename):
    """Write ``data`` to ``filename`` as JSON, one step object per line.

    Each item in ``data`` must provide ``"index"`` and ``"steps"``; each step
    must provide ``"id"``, ``"function"`` and ``"inputs"``.  An item's
    ``"instruction"`` is written as-is (``""`` when the key is absent).  All
    values are emitted via ``json.dumps`` so quotes and backslashes are
    escaped; non-ASCII characters (e.g. ``°``) are written literally.

    The whole document is assembled before the file is opened, so a
    malformed item raises without leaving a half-written file behind.
    """
    def dump(value):
        return json.dumps(value, ensure_ascii=False)

    def format_step(step):
        return (
            f'{{"id": {dump(step["id"])}, '
            f'"function": {dump(step["function"])}, '
            f'"inputs": {dump(step["inputs"])}}}'
        )

    chunks = []
    for item in data:
        step_lines = ",\n".join(f"      {format_step(step)}" for step in item["steps"])
        chunks.append(
            "  {\n"
            f'    "index": {dump(item["index"])},\n'
            f'    "instruction": {dump(item.get("instruction", ""))},\n'
            '    "steps": [\n'
            f"{step_lines}\n"
            "    ]\n"
            "  }"
        )

    body = ",\n".join(chunks) + "\n" if chunks else ""
    with open(filename, "w", encoding="utf-8") as f:
        f.write("[\n" + body + "]\n")

# Entry point: generate 100 samples and save them next to the script.
if __name__ == "__main__":
    output_path = "localization_samples.json"
    write_custom_json(generate_localization_samples(100), output_path)
    print(f"✅ Saved to {output_path} with all steps contributing.")