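# BigCodeBench evaluation harness exposed as a Gradio app.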
import gradio as gr
import json
import multiprocessing
import os
import pickle
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed, wait, FIRST_COMPLETED
from datetime import datetime
from typing import Any, Dict, List, Tuple
from warnings import warn

import numpy as np
from termcolor import cprint
from tqdm import tqdm

from bigcodebench.data import get_bigcodebench, get_bigcodebench_hash, load_solutions
from bigcodebench.data.utils import CACHE_DIR
from bigcodebench.eval import PASS, compatible_eval_result, estimate_pass_at_k, untrusted_check
from bigcodebench.gen.util import trusted_check

Result = Tuple[str, List[bool]]

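# Compute the expected (ground-truth) runtime of every task by executing its canonical
# solution against its tests, caching the result on disk keyed by the dataset hash.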
def get_groundtruth(n_workers, problems, hashcode, check_gt_only, max_as_limit, max_data_limit, max_stack_limit, min_time_limit):
    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    if os.path.exists(cache_file):
        if check_gt_only:
            os.remove(cache_file)
        else:
            print(f"Loading ground-truth from {cache_file}")
            with open(cache_file, "rb") as f:
                return pickle.load(f)

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("\nAsserting the groundtruth...")
    tbegin = time.time()
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = []
        n_samples = 0
        expected_time = dict()

        for problem in problems.values():
            args = (
                problem["complete_prompt"] + "\n" + problem["canonical_solution"],
                problem["test"],
                problem["task_id"],
                max_as_limit,
                max_data_limit,
                max_stack_limit,
                min_time_limit,
            )
            futures.append(executor.submit(trusted_check, *args))
            n_samples += 1

        for future in tqdm(as_completed(futures), total=n_samples):
            result = future.result()
            expected_time[result["task_id"]] = result["time"]
    print(f"Expected outputs computed in {time.time() - tbegin:.2f}s")

    if any(expected_time.values()):
        with open(cache_file, "wb") as f:
            pickle.dump(expected_time, f)

    return expected_time

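# Run a single model-generated solution against the task's tests under the given
# memory/stack/time limits via untrusted_check.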
def check_correctness(
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Result]:
    ret = {
        "completion_id": completion_id,
        "task_id": problem["task_id"],
        "_identifier": identifier,
        "solution": solution,
    }
    ret["base"] = untrusted_check(
        solution,
        problem["test"],
        problem["entry_point"],
        max_as_limit,
        max_data_limit,
        max_stack_limit,
        min_time_limit,
        gt_time_limit,
    )
    return ret

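# Evaluate a file of model samples on BigCodeBench and return pass@k plus ground-truth stats.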
def evaluate(
    split: str,
    subset: str,
    samples: str,
    pass_k: str = "1,5,10",
    parallel: int = None,
    min_time_limit: float = 1,
    max_as_limit: int = 30 * 1024,
    max_data_limit: int = 30 * 1024,
    max_stack_limit: int = 10,
    check_gt_only: bool = False,
    no_gt: bool = False,
):
    pass_k = [int(k.strip()) for k in pass_k.split(',') if k.strip().isdigit()]

    if parallel is None:
        n_workers = max(1, multiprocessing.cpu_count() // 2)
    else:
        n_workers = parallel

    if check_gt_only:
        samples = "__dummy__.jsonl"

    extra = subset + "_" if subset != "full" else ""
    if os.path.isdir(samples):
        result_path = os.path.join(samples, f"{extra}eval_results.json")
    else:
        assert samples.endswith(".jsonl")
        result_path = samples.replace(".jsonl", f"_{extra}eval_results.json")

    problems = get_bigcodebench(subset=subset)
    dataset_hash = get_bigcodebench_hash(subset=subset)

    if not no_gt:
        expected_time = get_groundtruth(n_workers, problems, dataset_hash, check_gt_only, max_as_limit, max_data_limit, max_stack_limit, min_time_limit)
    else:
        expected_time = {task_id: None for task_id in problems}

    gt_pass_rate = np.mean([1 if v is not None else 0 for k, v in expected_time.items() if k in problems])
    failed_tasks = [k for k, v in expected_time.items() if v is None and k in problems]

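    # Reuse previously saved evaluation results if present; otherwise run the checks below.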
    if os.path.isfile(result_path):
        with open(result_path, "r") as f:
            results = json.load(f)
        results = compatible_eval_result(results)
    else:
        if check_gt_only:
            if gt_pass_rate > 0.99:
                cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}", "green")
            else:
                cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}\nPlease be cautious!", "red")
            if len(failed_tasks) > 0:
                cprint(f"Failed tasks: {failed_tasks}", "red")
            return {"gt_pass_rate": float(gt_pass_rate), "failed_tasks": failed_tasks}

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "eval": {},
        }

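        # Submit one correctness check per sample and collect the results as they complete.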
        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> list of per-sample results
            remainings = set()

            print("Reading samples...")
            for sample in tqdm(load_solutions(samples)):
                task_id = sample["task_id"]
                if task_id not in problems:
                    warn(
                        f"Task {task_id} is found in the samples but not found in the dataset"
                    )
                    continue
                solution = (
                    sample["solution"]
                    if "solution" in sample
                    else problems[task_id]["complete_prompt"] + sample["completion"]
                )
                if "sanitized-calibrated" in samples:
                    solution = problems[task_id]["code_prompt"] + "\n    pass\n" + solution
                remainings.add(sample["_identifier"])
                args = (
                    completion_id[task_id],
                    problems[task_id],
                    solution,
                    max_as_limit,
                    max_data_limit,
                    max_stack_limit,
                    sample["_identifier"],
                    min_time_limit,
                    expected_time[task_id] if expected_time[task_id] else 20,
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Missing problems in unfinished"
            assert len(completion_id) == len(problems), "Missing problems in samples"

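            # Watchdog thread: warn if no sample finishes within a 240s window.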
            def stucking_checker():
                not_done = futures
                while len(not_done) > 0:
                    done, not_done = wait(not_done, timeout=240, return_when=FIRST_COMPLETED)
                    if len(done) == 0:
                        warn("No samples have finished testing in the last 240s")
                        warn(f"{len(remainings)} samples to be tested: {remainings}")

            threading.Thread(target=stucking_checker).start()

            for future in tqdm(as_completed(futures), total=n_samples):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)

        # sort the results for each problem by completion_id
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:
                stat, details = res["base"]
                results["eval"][task_id].append(
                    {
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )

    # Calculate pass@k.
    total = np.array([len(r) for k, r in results["eval"].items() if k in problems])
    base_correct = []
    for key, res in results["eval"].items():
        if key not in problems:
            continue
        bc = sum([r["status"] == PASS for r in res])
        base_correct.append(bc)
    base_correct = np.array(base_correct)

    pass_at_k = {
        f"pass@{k}": float(estimate_pass_at_k(total, base_correct, k).mean())
        for k in pass_k
        if total.min() >= k
    }
    pass_at_k["gt_pass_rate"] = float(gt_pass_rate)
    pass_at_k["failed_tasks"] = failed_tasks

    return pass_at_k

    # mode = "-calibrated" if "sanitized-calibrated" in samples else ""
    # extra = subset.capitalize()
    # split = split.capitalize()
    # cprint(f"BigCodeBench-{split}{mode} ({extra})", "green")
    # if no_gt:
    #     cprint(f"Groundtruth is not checked", "yellow")
    # else:
    #     if gt_pass_rate > 0.99:
    #         cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}", "green")
    #     else:
    #         cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}\nPlease be cautious!", "red")
    #     if len(failed_tasks) > 0:
    #         cprint(f"Failed tasks: {failed_tasks}", "red")
    # for k, v in pass_at_k.items():
    #     cprint(f"{k}:\t{v:.3f}", "green")
    # # save results
    # if os.path.isfile(result_path):
    #     decision = ""
    #     while decision.lower() not in ["y", "n"]:
    #         print(f"{result_path} already exists. Press [Y/N] to overwrite or exit...")
    #         decision = input()
    #     if decision.lower() == "y":
    #         # mv the file to a backup
    #         new_path = result_path + ".bak"
    #         while os.path.isfile(new_path):
    #             new_path += ".bak"
    #         os.rename(result_path, new_path)
    #         print(f"Backup {result_path} to {new_path}")
    # if not os.path.isfile(result_path):
    #     with open(result_path, "w") as f:
    #         json.dump(results, f, indent=2)
    # if save_pass_rate:
    #     pass_at_k_path = result_path.replace("_eval_results.json", "_pass_at_k.json")
    #     pass_at_k["model"] = os.path.basename(samples).split("--bigcodebench-")[0]
    #     pass_at_k["calibrated"] = "sanitized-calibrated" in samples
    #     pass_at_k["subset"] = subset
    #     def save_pass_at_k():
    #         with open(pass_at_k_path, "w") as f:
    #             json.dump(pass_at_k, f, indent=2)
    #     if os.path.isfile(pass_at_k_path):
    #         saved_pass_at_k = json.load(open(pass_at_k_path, "r"))
    #         # compare saved_pass_at_k with pass_at_k
    #         for k in saved_pass_at_k.keys():
    #             if pass_at_k[k] != saved_pass_at_k[k]:
    #                 cprint(f"Warning: {k} is different from the saved one", "yellow")
    #         # ask user whether to save the pass@k
    #         decision = ""
    #         while decision.lower() not in ["y", "n"]:
    #             print(f"Save pass@k to {pass_at_k_path}? [Y/N]")
    #             decision = input()
    #         if decision.lower() == "y":
    #             save_pass_at_k()
    #     else:
    #         save_pass_at_k()

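# Build and launch the Gradio interface that exposes evaluate() as a web form.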
def run_gradio():
    interface = gr.Interface(
        fn=evaluate,
        inputs=[
            gr.Dropdown(["complete", "instruct"], label="Split"),
            gr.Dropdown(["full", "hard"], label="Subset"),
            gr.File(label="Samples Path (.jsonl)"),
            gr.Textbox(label="Pass k Values (comma-separated)", value="1,5,10"),
            gr.Slider(1, multiprocessing.cpu_count(), step=1, label="Parallel Workers"),
            gr.Slider(0.1, 10, step=0.1, label="Min Time Limit", value=1),
            gr.Slider(1, 100 * 1024, step=1024, label="Max AS Limit", value=30 * 1024),
            gr.Slider(1, 100 * 1024, step=1024, label="Max Data Limit", value=30 * 1024),
            gr.Slider(1, 100, step=1, label="Max Stack Limit", value=10),
            gr.Checkbox(label="Check GT Only"),
            gr.Checkbox(label="No GT"),
        ],
        outputs="text",
        # concurrency_limit=None
    )
    interface.queue(default_concurrency_limit=None)
    interface.launch(show_error=True)

if __name__ == "__main__":
    run_gradio()
    # evaluate("complete", "hard", "meta-llama--Llama-3.2-3B-Instruct--bigcodebench-instruct--vllm-0-1.jsonl")