import gzip
import itertools
import json
import os
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Iterable, List, Optional, Union

import numpy as np
from tqdm.auto import tqdm

from human_eval.data import stream_jsonl
from human_eval.execution import check_correctness


IMPORT_HELPER = {
    "python": [
        "import math",
        "import re",
        "import sys",
        "import copy",
        "import datetime",
        "import itertools",
        "import collections",
        "import heapq",
        "import functools",
        "import hashlib",
        "import numpy",
        "import numpy as np",
        "import string",
        "from typing import *",
        "from collections import *",
    ],
    "go": [
        "math",
        "strings",
        "fmt",
        "strconv",
        "time",
        "bytes",
        "regexp",
        "sort",
        "math/rand",
        "crypto/md5",
    ],
    "cpp": [
        "#include<stdlib.h>",
        "#include<algorithm>",
        "#include<math.h>",
        "#include<stdio.h>",
        "#include<vector>",
        "#include<string>",
        "#include<climits>",
        "#include<cstring>",
        "#include<iostream>",
        "#include<cassert>",
    ],
    "cs": [
        "using System.Numerics;",
        "using System.Diagnostics;",
        "using System.Collections.Generic;",
        "using System.Linq;",
        "using System.Text;",
        "using System.Security.Cryptography;",
    ],
}


LANGUAGE_NAME = {
    "cpp": "CPP",
    "go": "Go",
    "java": "Java",
    "js": "JavaScript",
    "python": "Python",
}


def read_dataset(
    data_file: Optional[str] = None,
    dataset_type: str = "humaneval",
    num_shot: Optional[int] = None,
) -> Dict:
    """
    Reads a dataset and returns a dictionary mapping task IDs to tasks.
    """
    if num_shot is not None:
        print(f"{num_shot}-shot setting...")
    if "humaneval" in dataset_type.lower():
        if data_file is None:
            current_path = os.path.dirname(os.path.abspath(__file__))
            data_file = os.path.join(
                current_path,
                "..",
                "humaneval-x",
                "python",
                "data",
                "humaneval_python.jsonl.gz",
            )
        dataset = {task["task_id"]: task for task in stream_jsonl(data_file)}
    else:
        # Raising a plain string is a TypeError in Python 3; raise a real exception.
        raise ValueError(f"Dataset: {dataset_type} not supported.")

    return dataset
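

# Usage sketch (illustrative, not part of the original module): load the
# bundled HumanEval-Python split and inspect one task. The default path is
# resolved relative to this file, exactly as read_dataset() does above.
#
#     problems = read_dataset(dataset_type="humaneval")
#     first = next(iter(problems.values()))
#     print(first["task_id"], first["prompt"][:80])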


def estimate_pass_at_k(
    num_samples: Union[int, List[int], np.ndarray],
    num_correct: Union[List[int], np.ndarray],
    k: int,
) -> np.ndarray:
    """
    Estimates pass@k of each problem and returns them in an array.
    """

    def estimator(n: int, c: int, k: int) -> float:
        """
        Calculates 1 - comb(n - c, k) / comb(n, k).
        """
        if n - c < k:
            return 1.0
        # Product form of 1 - C(n-c, k) / C(n, k); avoids the overflow that
        # evaluating the two binomial coefficients directly can cause.
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    if isinstance(num_samples, int):
        num_samples_it = itertools.repeat(num_samples, len(num_correct))
    else:
        assert len(num_samples) == len(num_correct)
        num_samples_it = iter(num_samples)

    return np.array(
        [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
    )
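

# Worked example (illustrative, not part of the original module): with n = 5
# samples per task and c = 3, 0, 5 correct samples respectively, pass@2 per
# task is 1 - C(n-c, 2) / C(n, 2):
#
#     estimate_pass_at_k(num_samples=5, num_correct=[3, 0, 5], k=2)
#     # -> array([0.9, 0. , 1. ])   # e.g. 1 - C(2,2)/C(5,2) = 1 - 1/10 = 0.9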


def process_humaneval_test(
    sample, problems, example_test=False, is_mbpp=False, language="python"
):
    """
    Builds the full test program for a sample; returns None for unsupported
    languages.
    """
    task_id = sample["task_id"]
    if is_mbpp:
        return sample["generation"] + "\n" + "\n".join(problems[task_id]["test"])

    prompt = sample["prompt"]
    if (
        example_test
        and "example_test" in problems[task_id]
        and problems[task_id]["example_test"] != ""
    ):
        test = problems[task_id]["example_test"]
    else:
        test = problems[task_id]["test"]
    code = sample["generation"]

    test_string = None  # stays None for languages without a branch below
    if language == "python":
        test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
        test_string = test_setup + code + "\n" + test + "\n"
    elif language == "cpp":
        test_set_up = ""
        for s in IMPORT_HELPER["cpp"]:
            if s not in prompt:
                test_set_up += s + "\n"
        test_string = test_set_up + "\n" + code + "\n" + test
    elif language == "java":
        test_string = code + "\n" + test
    elif language == "cs":
        test_set_up = ""
        for s in IMPORT_HELPER["cs"]:
            test_set_up += s + "\n"
        test_string = test_set_up + "\n" + code + "\n" + test
    elif language in ["js", "javascript", "ts", "sh", "go"]:
        test_string = code + "\n" + test
    elif language == "go232":
        # NOTE: "go" is already handled by the branch above, so this legacy
        # branch (which stitches in the Go import block) is unreachable as
        # written; it is kept for reference.
        import_string = problems[task_id]["import"]
        prompt = prompt.replace(import_string, "")
        if example_test and "example_test" in problems[task_id]:
            test = problems[task_id]["example_test"]
        else:
            test = problems[task_id]["test"]
        test_setup = problems[task_id]["test_setup"]
        other_pkgs = []
        for pkg in IMPORT_HELPER["go"]:
            if pkg not in test_setup:
                p = pkg.split("/")[-1]
                if p + "." in code:
                    other_pkgs.append(f'"{pkg}"')
        if other_pkgs:
            import_other_pkgs = (
                "import (\n" + " ".join([p + "\n" for p in other_pkgs]) + ")"
            )
            test_string = (
                test_setup
                + "\n"
                + import_other_pkgs
                + "\n"
                + prompt
                + code
                + "\n"
                + test
            )
        else:
            test_string = test_setup + "\n" + prompt + code + "\n" + test
    elif language == "rust":
        main = "\nfn main(){ \n } \n"
        declaration = problems[task_id]["declaration"]
        test_string = main + declaration + prompt + code + test
    elif language == "php":
        if not code.startswith("<?php"):
            code = "<?php\n" + code
        test_string = code + "\n" + test + "?>"
    return test_string
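

# Illustrative sketch (sample/problem shapes assumed from the code above, not
# from the original module): for Python, the returned program is the import
# preamble, the generated solution, and the task's tests concatenated:
#
#     sample = {"task_id": "Python/0", "prompt": "def add(a, b):\n",
#               "generation": "def add(a, b):\n    return a + b\n"}
#     problems = {"Python/0": {"test": "assert add(1, 2) == 3"}}
#     print(process_humaneval_test(sample, problems, language="python"))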


def stream_jsonl_all(filename: str) -> Iterable[Dict]:
    """
    Reads every record of a (possibly gzipped) JSONL file into a list.
    """
    results = []
    if filename.endswith(".gz"):
        fp = gzip.open(filename, "rt")
    else:
        fp = open(filename, "r")
    for line in fp:
        if line.strip():
            results.append(json.loads(line))
    fp.close()

    return results
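

# Usage sketch (file name is hypothetical): plain and gzipped JSONL both
# work, one JSON object per line; blank lines are skipped.
#
#     samples = stream_jsonl_all("samples.jsonl")  # or "samples.jsonl.gz"
#     print(len(samples), samples[0]["task_id"])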


def evaluate_functional_correctness(
    input_file: Optional[str] = None,
    tmp_dir: str = "./",
    n_workers: int = 32,
    timeout: float = 10.0,
    problem_file: str = "../data/humaneval_python.jsonl.gz",
    out_path: Optional[str] = None,
    k: List[int] = [1, 10, 100],
    test_groundtruth: bool = False,
    example_test: bool = False,
    is_mbpp: bool = False,
    language: str = "python",
):
    """
    Evaluates the functional correctness of a model's samples.
    """
    if example_test:
        print("Example test...")

    problems = read_dataset(problem_file, dataset_type="humaneval")
    sample_jsonl = stream_jsonl_all(input_file)

    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        futures = []
        completion_id = Counter()
        n_samples = 0

        # Group results per task so that multiple completions of the same
        # task are all counted (a plain dict would overwrite them).
        results = defaultdict(list)

        if test_groundtruth:
            print("Testing ground truth...")
            for sample in tqdm(problems.values()):
                task_id = sample["task_id"]
                lang = task_id.split("/")[0].lower()
                if lang == "javascript":
                    lang = "js"
                tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
                sample["generation"] = sample["canonical_solution"]
                # Pass the language by keyword: positionally, the fourth
                # argument of process_humaneval_test is is_mbpp, not language.
                sample["test_code"] = process_humaneval_test(
                    sample, problems, example_test, language=lang
                )
                if sample["test_code"] is None:
                    continue
                args = (
                    task_id,
                    sample,
                    lang,
                    timeout,
                    tmp_dir_,
                    completion_id[task_id],
                )
                future = executor.submit(check_correctness, *args)
                futures.append(future)
                completion_id[task_id] += 1
                n_samples += 1
        else:
            print("Reading samples...")
            for sample in tqdm(sample_jsonl):
                task_id = sample["task_id"]
                lang = "python" if is_mbpp else language
                if not is_mbpp and lang == "javascript":
                    lang = "js"
                tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
                sample["test_code"] = process_humaneval_test(
                    sample, problems, example_test, is_mbpp, language
                )
                if sample["test_code"] is None:
                    continue
                if "completion_id" in sample:
                    completion_id_ = sample["completion_id"]
                else:
                    completion_id_ = completion_id[task_id]
                args = (task_id, sample, lang, timeout, tmp_dir_, completion_id_)
                future = executor.submit(check_correctness, *args)
                futures.append(future)
                completion_id[task_id] += 1
                n_samples += 1

        evaluate_pass_at_k = len(completion_id) == len(problems)

        print("Running test suites...")
        for future in tqdm(as_completed(futures), total=len(futures)):
            result = future.result()
            results[result["task_id"]].append(result)

    total, correct = [], []
    for task_results in results.values():
        passed = [r["passed"] for r in task_results]
        total.append(len(passed))
        correct.append(sum(passed))
    total = np.array(total)
    correct = np.array(correct)

    # Initialize so the function does not hit a NameError when pass@k is
    # not evaluated.
    pass_at_k = {}
    if evaluate_pass_at_k:
        ks = k
        pass_at_k = {
            f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()
            for k in ks
            if (total >= k).all()
        }
        print(pass_at_k)
    else:
        print("Total:", np.sum(total))
        print("Correct:", np.sum(correct))

    if out_path:
        with open(out_path, "w") as f:
            json.dump(
                [r for task_results in results.values() for r in task_results],
                f,
                ensure_ascii=False,
            )

    return pass_at_k
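

# Hedged end-to-end sketch (paths are hypothetical, not from the original
# module): score a file of generations against the bundled problems and write
# per-sample results alongside the pass@k summary.
#
#     if __name__ == "__main__":
#         scores = evaluate_functional_correctness(
#             input_file="generations.jsonl",
#             problem_file="../data/humaneval_python.jsonl.gz",
#             out_path="results.json",
#             k=[1, 10],
#             language="python",
#         )
#         print(scores)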