Spaces:
Runtime error
Runtime error
File size: 2,308 Bytes
c87c295 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import json
import os
import pickle
from rich.progress import track
from tools.tsr.utils import get_task_ids, to_path
def collect_sample_info(sample_dir: str, sample_eval_dir: str, dataset: str) -> None:
    """Build the per-task "kill" cache from sample evaluation results.

    For every task id returned by ``get_task_ids(dataset)`` one pickle file
    named ``<to_path(task_id)>.pkl`` is written to ``sample_dir``. Each
    pickle holds a dict mapping a test id (``"plus_<test index>"``) to the
    list of ``(model_dir_name, code_index)`` pairs whose result for that
    test compared equal to ``False``.

    If ``sample_dir`` already exists and is non-empty it is treated as an
    existing cache and the function returns without doing anything.

    Args:
        sample_dir: Directory the cache files are written to; created if
            it does not exist.
        sample_eval_dir: Directory whose entries are scanned for
            ``<entry>/eval_results.json``. Entries whose name does not end
            in a digit, or that have no ``eval_results.json``, are skipped.
        dataset: Dataset name forwarded to ``get_task_ids``.

    Raises:
        AssertionError: If ``sample_eval_dir`` does not exist.
    """
    if os.path.exists(sample_dir) and os.listdir(sample_dir):
        # cache file exists
        return

    # Raised explicitly (rather than via ``assert``) so the check is not
    # stripped under ``python -O``; the exception type is kept for callers.
    # Done before the task-id lookup so a bad path fails before any other work.
    if not os.path.exists(sample_eval_dir):
        raise AssertionError("sample evaluation files missing")

    task_ids = get_task_ids(dataset)
    os.makedirs(sample_dir, exist_ok=True)
    kill_info = {task_id: {} for task_id in task_ids}

    model_paths = os.listdir(sample_eval_dir)
    for model_path in track(model_paths, description="Collecting sets..."):
        # Only entries whose name ends in a digit are considered.
        if not model_path[-1].isdigit():
            continue
        eval_json_path = os.path.join(
            sample_eval_dir, model_path, "eval_results.json"
        )
        if not os.path.exists(eval_json_path):
            continue

        # Load and close the file before aggregating; JSON text is UTF-8.
        with open(eval_json_path, "r", encoding="utf-8") as f:
            eval_results = json.load(f)["eval"]

        for task_id, task_result in eval_results.items():
            # ``kill_info`` has exactly the ids from ``task_ids`` as keys, so
            # this is the same membership test with a dict lookup.
            if task_id not in kill_info:
                continue
            # NOTE(review): the id is filtered in its raw form above and only
            # afterwards normalised from "_" to "/". Presumably this was meant
            # to accept path-style ids -- confirm; ordering kept as it was.
            kill_key = task_id.replace("_", "/")
            for i_code, (status, res_list) in enumerate(task_result["plus"]):
                if status == "success":
                    continue
                for i_test, outcome in enumerate(res_list):
                    # Equality (not truthiness) on purpose: ``None`` or an
                    # empty value must not be recorded as a failed test.
                    if outcome == False:  # noqa: E712
                        kill_info[kill_key].setdefault(
                            f"plus_{i_test}", []
                        ).append((model_path, i_code))

    for task_id in task_ids:
        cache_file = os.path.join(sample_dir, f"{to_path(task_id)}.pkl")
        with open(cache_file, "wb") as f:
            pickle.dump(kill_info[task_id], f)
if __name__ == "__main__":
    import argparse

    # Script entry point: parse the CLI options and build the sample cache
    # under "<report_dir>/sample_cache".
    cli = argparse.ArgumentParser()
    cli.add_argument("--report_dir", required=True, type=str)
    cli.add_argument("--dataset", type=str, choices=["humaneval", "mbpp"])
    cli.add_argument("--sample_eval_dir", required=True, type=str)
    options = cli.parse_args()

    collect_sample_info(
        os.path.join(options.report_dir, "sample_cache"),
        options.sample_eval_dir,
        options.dataset,
    )
|