Spaces:
Runtime error
Runtime error
codellama-CodeLlama-7b-hf
/
Llama2-Code-Interpreter-main
/OpenCodeInterpreter
/evaluation
/evalplus
/tools
/tsr
/coverage_init.py
import os | |
import pickle | |
import sys | |
from importlib import import_module | |
from io import StringIO | |
from typing import Any, Dict, List | |
import coverage | |
from rich.progress import track | |
from evalplus.eval.utils import swallow_io | |
from tools.tsr.utils import get_problems, get_task_ids, to_path | |
class Capturing(list):
    """A list that records every line written to ``sys.stdout`` inside a ``with`` block.

    Entering the block replaces ``sys.stdout`` with an in-memory buffer.
    Leaving it puts the previous ``sys.stdout`` back and appends the buffered
    text to this list, one element per line.
    """

    def __enter__(self):
        self._stdout = sys.stdout
        self._stringio = StringIO()
        sys.stdout = self._stringio
        return self

    def __exit__(self, *args):
        captured_text = self._stringio.getvalue()
        del self._stringio
        sys.stdout = self._stdout
        # Returning None (falsy) lets any exception from the block propagate.
        self.extend(captured_text.splitlines())
def parse_lcov(outputs: List[str]):
    """Extract branch coverage for the generated ``tmp_src`` module from LCOV text.

    Only the lines of the first LCOV record whose line mentions ``tmp_src`` are
    considered. The record runs up to, but not including, its
    ``end_of_record`` line. Within that record every ``BRDA`` line is one branch.

    Args:
        outputs: LCOV report lines, as produced by ``Coverage.lcov_report``.

    Returns:
        ``(per, branch, branch_covered)``:

        - ``per`` is the fraction of branches taken. It is ``1.0`` when there
          are no branches.
        - ``branch`` lists every branch signature ``"BR:<lineno>,<block>,<branch>"``.
        - ``branch_covered`` is the subset whose taken-count is neither
          ``"0"`` nor ``"-"``.
    """
    in_record, extracted_outputs = False, []
    for line in outputs:
        if not in_record and "tmp_src" in line:
            in_record = True
        if in_record and "end_of_record" in line:
            in_record = False
        if in_record:
            extracted_outputs.append(line)

    branch, branch_covered = [], []
    for line in extracted_outputs:
        if not line.startswith("BRDA"):
            continue
        # LCOV branch record: BRDA:<lineno>,<blockno>,<branchno>,<taken>
        # <taken> is always the last field. Splitting it off from the right
        # and the numbers off from the left keeps any commas that appear in
        # the branch field instead of crashing the 4-way unpack.
        head, taken = line[5:].rsplit(",", 1)
        lineno, blockno, branchno = head.split(",", 2)
        branch_sig = f"BR:{lineno},{blockno},{branchno}"
        branch.append(branch_sig)
        # "-" means the enclosing block never ran; "0" means not taken.
        if taken not in ("0", "-"):
            branch_covered.append(branch_sig)

    per = len(branch_covered) / len(branch) if branch else 1.0
    return per, branch, branch_covered
def test_code_coverage(
    identifier: str, code: str, inputs: List[List[Any]], entry_point: str
):
    """Measure branch coverage of ``entry_point`` in ``code`` over ``inputs``.

    Steps:

    1. Write ``code`` to ``tmp_src_<identifier>.py`` in the current working
       directory and import it.
    2. Under a branch-enabled ``coverage.Coverage`` session, call
       ``entry_point`` once per argument list in ``inputs``, with stdio
       swallowed.
    3. Parse the resulting LCOV report.

    The temporary source file is removed even if any step raises.

    Returns:
        The ``(per, branch, branch_covered)`` tuple from ``parse_lcov``.

    Raises:
        AssertionError: if the imported module does not define ``entry_point``.
    """
    import importlib

    module_name = f"tmp_src_{identifier}"
    module_path = f"{module_name}.py"
    # Python source is UTF-8 by default; don't depend on the locale encoding.
    with open(module_path, "w", encoding="utf-8") as f:
        f.write(code)
    try:
        # The file was created after interpreter start, so the import system's
        # finder caches may not know about it yet.
        importlib.invalidate_caches()
        mod = import_module(module_name)
        func = getattr(mod, entry_point, None)
        if func is None:
            # Explicit raise (not `assert`) so the check survives `python -O`.
            raise AssertionError(
                f"entry_point = {entry_point} not exist, code: {code}"
            )

        cov = coverage.Coverage(branch=True)
        cov.start()
        try:
            with swallow_io():
                for input_list in inputs:
                    func(*input_list)
        finally:
            # Never leave the tracer running if the function under test raises.
            cov.stop()

        # outfile="-" sends the LCOV text to stdout, which Capturing collects.
        with Capturing() as outputs:
            cov.lcov_report(outfile="-")
        return parse_lcov(outputs)
    finally:
        os.remove(module_path)
def collect_coverage_info(coverage_dir: str, dataset: str) -> Dict[str, Dict[str, Any]]:
    """Collect per-test branch coverage of each task's ground-truth solution.

    For every task in ``dataset``, the ground truth (``prompt`` +
    ``canonical_solution``) is run separately on each ``plus_input`` test via
    ``test_code_coverage``. Results are cached as one pickle per task under
    ``coverage_dir``, and an existing cache file is loaded instead of recomputing.

    Args:
        coverage_dir: Directory for the per-task ``.pkl`` caches. It is created
            if missing.
        dataset: Dataset name passed to ``get_problems`` / ``get_task_ids``.

    Returns:
        ``{task_id: {"plus_<i>": [(branch_sig, "gt"), ...], ...}}``. Each entry
        lists the branches covered by plus test ``i``, each tagged ``"gt"``.
    """
    os.makedirs(coverage_dir, exist_ok=True)
    problems = get_problems(dataset)
    task_ids = get_task_ids(dataset)
    coverage_info = {task_id: {} for task_id in task_ids}

    for task_id in track(task_ids, description="Testing gt coverage..."):
        coverage_cache_path = os.path.join(coverage_dir, f"{to_path(task_id)}.pkl")
        if os.path.isfile(coverage_cache_path):
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # coverage_dir at caches written by this script.
            with open(coverage_cache_path, "rb") as f:
                coverage_info[task_id] = pickle.load(f)
            continue

        problem = problems[task_id]
        groundtruth_code = problem["prompt"] + problem["canonical_solution"]
        plus_tests = problem["plus_input"]
        entry_point = problem["entry_point"]
        task_info = coverage_info[task_id]
        for i, plus_test in enumerate(plus_tests):
            _, _, branch_covered = test_code_coverage(
                to_path(task_id), groundtruth_code, [plus_test], entry_point
            )
            task_info.setdefault(f"plus_{i}", []).extend(
                (br, "gt") for br in branch_covered
            )

        # Write to a temp file and rename it into place. An interrupted run
        # then cannot leave a truncated pickle that would pass the isfile
        # check above and crash every later run on load.
        tmp_path = coverage_cache_path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(task_info, f)
        os.replace(tmp_path, coverage_cache_path)

    return coverage_info
if __name__ == "__main__":
    # CLI entry point: build the ground-truth coverage cache for one dataset
    # under <report_dir>/coverage_cache.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset", type=str, choices=["humaneval", "mbpp"])
    cli.add_argument("--report_dir", required=True, type=str)
    cli_args = cli.parse_args()

    collect_coverage_info(
        os.path.join(cli_args.report_dir, "coverage_cache"), cli_args.dataset
    )