import logging

from utils import check_correctness
from evaluation.benchmarks.mint.tasks.base import Task

LOGGER = logging.getLogger('MINT')


class CodeGenTask(Task):
    """Generic code generation task instance."""

    def __init__(self, id: str, prompt: str, reference: str, **kwargs):
        super().__init__(**kwargs)
        self._id = id
        self._prompt = prompt
        self._reference = reference
    def success(self, solution: str) -> bool:
        """Check whether the given solution solves the current task.

        Can be used to provide binary feedback.
        """
        code_to_exec = self.extract_answer(solution)
        LOGGER.debug(f'CODE_TO_EXEC:\n{code_to_exec}')
        LOGGER.debug(f'TEST_CODE:\n{self._reference}')
        res = check_correctness(
            solution_code=code_to_exec, test_code=self._reference, timeout=10
        )
        return res['success']
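

# For reference: a minimal sketch of the contract `success` assumes the
# imported `check_correctness` satisfies -- run the solution followed by the
# reference tests and report a boolean 'success' key. The helper name and
# body below are illustrative assumptions, not the real utils implementation
# (which should sandbox execution; timeout enforcement is omitted here).
def _check_correctness_sketch(
    solution_code: str, test_code: str, timeout: int = 10
) -> dict:
    try:
        namespace: dict = {}
        exec(solution_code, namespace)  # define the candidate solution
        exec(test_code, namespace)  # run the reference assertions
        return {'success': True}
    except Exception:
        return {'success': False}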


class MBPPTask(CodeGenTask):
    task_name = 'mbpp'

    @property
    def prompt(self) -> str:
        """Return the prompt for this task.

        The MBPP prompt is enclosed in \"\"\" on both ends; remove them.
        """
        return self._prompt.replace('"""', '').strip()
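
    # Example for `prompt` (hypothetical sample; real MBPP prompts follow
    # this shape):
    #   _prompt = '"""Write a function to check whether a number is even."""'
    #   prompt  -> 'Write a function to check whether a number is even.'
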
    def extract_answer(self, solution: str) -> str | None:
        """Extract the answer from the given solution.

        The stop-word splitting below (scanning for class, assert, etc. at
        the start of a line to keep only the first block of code) is
        currently disabled; the solution is returned unchanged. A standalone
        sketch of the splitting follows this class.
        Modified from:
        https://github.com/bigcode-project/bigcode-evaluation-harness/blob/d61afde130005ecc65cf800ad8eca790a9bc2115/lm_eval/tasks/mbpp.py#L67
        """
        # STOP_WORDS = ["\nclass", "\nassert", '\n"""', "\nprint", "\nif", "\n<|/"]
        # return re.split("|".join(STOP_WORDS), solution)[0].rstrip()
        return solution
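

# What the disabled MBPP stop-word splitting above would do if re-enabled,
# written as a standalone sketch (hypothetical helper, not used by the task).
# Note that re.escape is needed: the original commented-out pattern joined
# the stop words unescaped, so the '|' in '\n<|/' would have been read as a
# regex alternation.
def _split_first_block_mbpp(solution: str) -> str:
    import re

    stop_words = ['\nclass', '\nassert', '\n"""', '\nprint', '\nif', '\n<|/']
    pattern = '|'.join(re.escape(w) for w in stop_words)
    # Keep only the text before the first stop word that starts a new line.
    return re.split(pattern, solution)[0].rstrip()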


class HumanEvalTask(CodeGenTask):
    task_name = 'humaneval'

    @property
    def prompt(self) -> str:
        """Return the prompt for this task.

        The HumanEval prompt is an unfinished function (signature and
        docstring), so prepend an instruction to complete it.
        """
        return 'Complete the following code:\n\n' + self._prompt
    def extract_answer(self, solution: str) -> str | None:
        """Extract the answer from the given solution.

        The stop-word splitting below (dropping the last block of code that
        begins at a stop word) is currently disabled; the solution is
        returned unchanged. A standalone sketch follows this class.
        Modified from:
        https://github.com/bigcode-project/bigcode-evaluation-harness/blob/d61afde130005ecc65cf800ad8eca790a9bc2115/lm_eval/tasks/humaneval.py#L56
        """
        # STOP_WORDS = ["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif"]
        # # Remove the last block of the code containing stop_words for HumanEval
        # string_list = re.split("(%s)" % "|".join(STOP_WORDS), solution)
        # # last string should be ""
        # return "".join(string_list[:-2])
        return solution
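

# Likewise, a sketch of the disabled HumanEval trimming above (hypothetical
# helper, not used by the task): drop the final block of code that begins at
# a stop word, keeping the completed function body.
def _trim_last_block_humaneval(solution: str) -> str:
    import re

    stop_words = ['\nclass', '\ndef', '\n#', '\n@', '\nprint', '\nif']
    # Split with a capturing group so the stop words are kept in the list;
    # dropping the last two entries removes the trailing stop word and
    # whatever followed it.
    pattern = '(%s)' % '|'.join(re.escape(w) for w in stop_words)
    parts = re.split(pattern, solution)
    # Guard added for the no-stop-word case (the original commented-out
    # code would have returned '' here).
    return ''.join(parts[:-2]) if len(parts) > 2 else solution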