Spaces:
Runtime error
Runtime error
File size: 3,869 Bytes
c87c295 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import importlib.util
import os
import pickle
import sys
from importlib import import_module
from io import StringIO
from typing import Any, Dict, List

import coverage
from rich.progress import track

from evalplus.eval.utils import swallow_io
from tools.tsr.utils import get_problems, get_task_ids, to_path
class Capturing(list):
    """Context manager that records what is printed to ``sys.stdout``.

    The manager is itself a ``list``. When the ``with`` block ends, the text
    written to stdout during the block is split into lines (newline characters
    removed) and appended to it. The previous ``sys.stdout`` is then restored.
    """

    def __enter__(self):
        # Keep a handle on the current stream so __exit__ can put it back.
        self._stdout = sys.stdout
        self._stringio = StringIO()
        sys.stdout = self._stringio
        return self

    def __exit__(self, *args):
        captured_lines = self._stringio.getvalue().splitlines()
        self.extend(captured_lines)
        del self._stringio
        # Restore the original stream; returning None lets exceptions propagate.
        sys.stdout = self._stdout
def parse_lcov(outputs: List[str], source_marker: str = "tmp_src"):
    """Extract branch coverage of the temporary source module from an LCOV report.

    Lines are kept from the first line containing ``source_marker`` up to, but
    not including, the next line containing ``end_of_record``. This repeats for
    every record that matches. Only ``BRDA`` lines in the kept region count.

    Args:
        outputs: the LCOV report, one report line per element.
        source_marker: substring that marks the start of a record to keep.
            Defaults to ``"tmp_src"``, the prefix of the generated module files.

    Returns:
        A tuple ``(per, branch, branch_covered)``:
        ``branch`` holds one ``"BR:<line>,<block>,<branch>"`` signature per
        ``BRDA`` entry; ``branch_covered`` holds the signatures whose taken
        count is neither ``"0"`` nor ``"-"``; ``per`` is
        ``len(branch_covered) / len(branch)``, or ``1.0`` when there are no
        branches.
    """
    in_record, extracted_outputs = False, []
    for line in outputs:
        if not in_record and source_marker in line:
            in_record = True
        if in_record and "end_of_record" in line:
            in_record = False
        if in_record:
            extracted_outputs.append(line)

    prefix = "BRDA:"
    branch, branch_covered = [], []
    for line in extracted_outputs:
        if not line.startswith(prefix):
            continue
        # BRDA format: BRDA:<lineno>,<blockno>,<branchno>,<taken>
        # Split off only the trailing <taken> field. For a normal four-field
        # record this matches a full split, and it also survives a branch
        # field that contains commas.
        location, taken = line[len(prefix):].rsplit(",", 1)
        branch_sig = f"BR:{location}"
        branch.append(branch_sig)
        # "0" means the branch was never taken; "-" means its line never ran.
        if taken not in ("0", "-"):
            branch_covered.append(branch_sig)
    per = 1.0 if not branch else len(branch_covered) / len(branch)
    return per, branch, branch_covered
def test_code_coverage(
    identifier: str, code: str, inputs: List[List[Any]], entry_point: str
):
    """Run ``entry_point`` from ``code`` on every input and measure branch coverage.

    ``code`` is written to ``tmp_src_<identifier>.py`` in the current working
    directory and imported by module name. The cleanup in the ``finally``
    block runs even if the import or the function under test raises. It
    removes the temporary source file, its cached bytecode, and the
    ``sys.modules`` entry.

    NOTE(review): importing by bare module name assumes the current working
    directory is on ``sys.path`` (e.g. the tool is launched with ``python -m``
    from the repository root) — confirm.

    Args:
        identifier: suffix that makes the temporary module name unique.
        code: Python source expected to define ``entry_point``.
        inputs: argument lists; the function is called as ``func(*input_list)``.
        entry_point: name of the function to call.

    Returns:
        The ``(per, branch, branch_covered)`` tuple from ``parse_lcov``.
    """
    module_name = f"tmp_src_{identifier}"
    module_path = f"{module_name}.py"
    # Python decodes source files as UTF-8 by default, so write them that way
    # regardless of the platform's locale encoding.
    with open(module_path, "w", encoding="utf-8") as f:
        f.write(code)
    try:
        # Drop any module cached under this name by an earlier call;
        # otherwise import_module would return it and ignore the new code.
        sys.modules.pop(module_name, None)
        # The file was just created. The import system may have cached the
        # directory listing, so invalidate its caches before importing.
        importlib.invalidate_caches()
        mod = import_module(module_name)
        func = getattr(mod, entry_point, None)
        assert func is not None, f"entry_point = {entry_point} not exist, code: {code}"

        cov = coverage.Coverage(branch=True)
        cov.start()
        try:
            with swallow_io():
                for input_list in inputs:
                    func(*input_list)
        finally:
            # Always stop tracing, even if the function under test raised.
            cov.stop()

        # outfile="-" sends the LCOV report to stdout, which Capturing collects
        # line by line. The report reads the source file, so it must run
        # before cleanup.
        with Capturing() as outputs:
            cov.lcov_report(outfile="-")
        return parse_lcov(outputs)
    finally:
        sys.modules.pop(module_name, None)
        leftovers = [module_path]
        if sys.implementation.cache_tag is not None:
            leftovers.append(importlib.util.cache_from_source(module_path))
        for path in leftovers:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
def collect_coverage_info(coverage_dir: str, dataset: str) -> Dict[str, Dict[str, Any]]:
    """Compute, or load from cache, ground-truth branch coverage for each plus test.

    For every task in ``dataset``, the ground-truth program (``prompt`` +
    ``canonical_solution``) is run on each ``plus_input`` separately. The
    branch signatures it covers are recorded. Per-task results are pickled
    into ``coverage_dir`` and reused on later runs.

    Args:
        coverage_dir: directory for the per-task ``<to_path(task_id)>.pkl``
            caches; created if missing.
        dataset: dataset name passed to ``get_problems`` and ``get_task_ids``.

    Returns:
        ``{task_id: {"plus_<i>": [(branch_sig, "gt"), ...]}}``.
    """
    os.makedirs(coverage_dir, exist_ok=True)
    problems = get_problems(dataset)
    task_ids = get_task_ids(dataset)
    coverage_info = {task_id: {} for task_id in task_ids}
    for task_id in track(task_ids, description="Testing gt coverage..."):
        coverage_cache_path = os.path.join(coverage_dir, f"{to_path(task_id)}.pkl")
        if os.path.isfile(coverage_cache_path):
            # NOTE: pickle.load can execute arbitrary code, so only point
            # coverage_dir at caches this tool produced itself.
            with open(coverage_cache_path, "rb") as f:
                coverage_info[task_id] = pickle.load(f)
            continue

        problem = problems[task_id]
        groundtruth_code = problem["prompt"] + problem["canonical_solution"]
        plus_tests = problem["plus_input"]
        entry_point = problem["entry_point"]
        task_coverage = coverage_info[task_id]
        for i, plus_test in enumerate(plus_tests):
            # Measure each input on its own so coverage is attributed per test.
            _, _, branch_covered = test_code_coverage(
                to_path(task_id), groundtruth_code, [plus_test], entry_point
            )
            task_coverage[f"plus_{i}"] = [(br, "gt") for br in branch_covered]

        # Write to a temporary file, then atomically move it into place. An
        # interrupted run therefore can't leave a truncated pickle that would
        # make the cache-loading branch above fail on every later run.
        tmp_path = coverage_cache_path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(task_coverage, f)
        os.replace(tmp_path, coverage_cache_path)
    return coverage_info
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Collect ground-truth branch coverage for every plus test."
    )
    # Required: with no default, omitting it would pass None into
    # collect_coverage_info instead of failing with a clear usage error.
    parser.add_argument(
        "--dataset", required=True, type=str, choices=["humaneval", "mbpp"]
    )
    parser.add_argument(
        "--report_dir",
        required=True,
        type=str,
        help="directory in which the coverage_cache/ subdirectory is created",
    )
    args = parser.parse_args()
    coverage_dir = os.path.join(args.report_dir, "coverage_cache")
    collect_coverage_info(coverage_dir, args.dataset)
|