Spaces:
Runtime error
Runtime error
File size: 4,778 Bytes
17e77ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import json
import random
from collections import defaultdict, deque
def _is_all_steps_connected(steps):
    """Return True if every step feeds, directly or transitively, into the last step.

    A step depends on an earlier step by listing that step's integer id in its
    ``inputs``; string inputs (``LOC_*`` names, directions, angles, distances)
    are not dependencies. The walk starts at ``steps[-1]`` and follows those
    integer references backwards.

    Args:
        steps: Non-empty list of step dicts with ``"id"`` and ``"inputs"`` keys.
    """
    parents = {
        step["id"]: [inp for inp in step["inputs"] if isinstance(inp, int)]
        for step in steps
    }
    visited = set()
    stack = [steps[-1]["id"]]
    while stack:
        curr = stack.pop()
        if curr in visited:
            continue  # already expanded; avoids re-walking shared parents
        visited.add(curr)
        stack.extend(parents.get(curr, ()))
    return visited.issuperset(parents)


def _draw_candidate(rng):
    """Draw one random candidate program.

    Args:
        rng: Either the ``random`` module or a ``random.Random`` instance.

    Returns:
        Tuple ``(locations, used_locations, steps)``: the ``LOC_*`` names made
        available, the subset referenced by at least one step, and the list of
        step dicts. ``steps`` may be empty when every draw was a ``Between``
        that had fewer than two references to choose from.
    """
    num_locations = rng.randint(1, 3)
    locations = [f"LOC_{i+1}" for i in range(num_locations)]
    used_locations = set()
    # A step input may be a location name or the id of any earlier step.
    all_refs = locations.copy()
    steps = []
    num_steps = rng.randint(2, 5)
    for _ in range(num_steps):
        func = rng.choice(["Relative", "Azimuth", "Between"])
        step_id = len(steps) + 1  # ids are consecutive over emitted steps only
        if func == "Between":
            if len(all_refs) < 2:
                continue  # not enough references yet; this draw yields no step
            base1, base2 = rng.sample(all_refs, 2)
            bases = (base1, base2)
            inputs = [base1, base2]
        else:
            base = rng.choice(all_refs)
            if func == "Relative":
                modifier = rng.choice([
                    "north", "south", "east", "west",
                    "northeast", "northwest", "southeast", "southwest"
                ])
            else:
                modifier = f"{rng.randint(0, 359)}°"
            distance = f"{rng.randint(1, 10)} km"
            bases = (base,)
            inputs = [base, modifier, distance]
        used_locations.update(b for b in bases if isinstance(b, str))
        steps.append({"id": step_id, "function": func, "inputs": inputs})
        all_refs.append(step_id)
    return locations, used_locations, steps


def generate_localization_samples(n, seed=None):
    """Generate ``n`` random localization samples.

    Each sample is ``{"index": i, "instruction": "", "steps": [...]}`` with
    1-based consecutive indices. Candidates are drawn repeatedly and kept only
    if (a) every location made available to the candidate is referenced by at
    least one step and (b) every step contributes, transitively, to the last
    step. Rejected candidates are discarded and redrawn.

    Args:
        n: Number of samples to produce; values <= 0 yield an empty list.
        seed: Optional seed for a private ``random.Random`` so results are
            reproducible. When None (the default) the global ``random``
            module state is used, as before.

    Returns:
        List of sample dicts.
    """
    rng = random if seed is None else random.Random(seed)
    all_data = []
    while len(all_data) < n:
        locations, used_locations, steps = _draw_candidate(rng)
        if not steps:
            continue  # no valid step was drawn; regenerate
        if used_locations.issuperset(locations) and _is_all_steps_connected(steps):
            all_data.append({
                "index": len(all_data) + 1,
                "instruction": "",
                "steps": steps,
            })
        # otherwise discard and regenerate
    return all_data
def write_custom_json(data, filename):
    """Write samples to ``filename`` as JSON, one step object per line.

    The layout is hand-formatted so each step stays on a single line. Every
    value is encoded with ``json.dumps``, so the result is always valid JSON
    that ``json.load`` reads back into ``data``.

    Args:
        data: List of sample dicts with ``"index"`` and ``"steps"`` keys and an
            optional ``"instruction"`` (missing means ``""``). Each step has
            ``"id"``, ``"function"`` and ``"inputs"``.
        filename: Destination path; the file is overwritten (UTF-8).
    """
    def dumps(value):
        # ensure_ascii=False keeps characters such as "°" readable in the file.
        return json.dumps(value, ensure_ascii=False)

    def format_step(step):
        return (
            f'{{"id": {dumps(step["id"])}, "function": {dumps(step["function"])}, '
            f'"inputs": {dumps(step["inputs"])}}}'
        )

    # Build the whole document first so a formatting error cannot leave a
    # truncated, half-written file behind.
    chunks = ["[\n"]
    for i, item in enumerate(data):
        chunks.append(" {\n")
        chunks.append(f' "index": {dumps(item["index"])},\n')
        # Use the sample's own instruction rather than assuming it is empty.
        chunks.append(f' "instruction": {dumps(item.get("instruction", ""))},\n')
        chunks.append(' "steps": [\n')
        chunks.append(",\n".join(f" {format_step(step)}" for step in item["steps"]))
        chunks.append("\n ]\n")
        chunks.append(" }" + (",\n" if i < len(data) - 1 else "\n"))
    chunks.append("]\n")
    with open(filename, "w", encoding="utf-8") as f:
        f.write("".join(chunks))
# Script entry point: build 100 samples and save them to a JSON file
# (relative path, so it lands in the current working directory).
if __name__ == "__main__":
    dataset = generate_localization_samples(100)
    output_file = "localization_samples.json"
    write_custom_json(dataset, output_file)
    print("✅ Saved to localization_samples.json with all steps contributing.")
|