# Copyright (c) Microsoft Corporation. # Licensed under the MIT license. from typing import Optional, Dict import contextlib import faulthandler import io import os import multiprocessing import platform import signal import tempfile def test_case_against_solution(prompt: str, completion: str, test: str, timeout: float): """ Get the output of a testcases against a solution """ def unsafe_execute(): with create_tempdir(): # These system calls are needed when cleaning up tempdir. import os import shutil rmtree = shutil.rmtree rmdir = os.rmdir chdir = os.chdir # Disable functionalities that can make destructive changes to the test. reliability_guard() # Construct the check program and run it. VERY IMPORTANT!!! import_packages_end_index = prompt.find("def") # some HumanEval prompts contain helper function that has # to be included. Patching. helper_func = "" helper_funcs_to_be_included = \ ["\n\ndef is_palindrome(string: str) -> bool:\n" " \"\"\" Test if given string is a palindrome \"\"\"\n" " return string == string[::-1]\n\n\n", "import math\n\n\ndef poly(xs: list, x: float):\n" " \"\"\"\n Evaluates polynomial with" " coefficients xs at point x.\n" " return xs[0] + xs[1] * x + xs[1] * x^2 + .... xs[n] * x^n\n" " \"\"\"\n return sum([coeff * math.pow(x, i) for i," " coeff in enumerate(xs)])\n\n\n", "\n\ndef encode_cyclic(s: str):\n" " \"\"\"\n returns encoded" " string by cycling groups of three" " characters.\n \"\"\"\n" " # split string to groups." " Each of length 3.\n" " groups" " = [s[(3 * i):min((3 * i + 3)," " len(s))] for" " i in range((len(s) + 2) // 3)]\n" " # cycle elements in each" " group. Unless group has fewer" " elements than 3.\n " "groups = [(group[1:] + group[0])" " if len(group) == 3 " "else group for group in groups]\n" " " "return \"\".join(groups)\n\n\n", "\n\ndef encode_shift(s: str):\n \"\"\"\n" " returns encoded string by shifting every" " character by 5 in the alphabet.\n \"\"\"\n" " return \"\".join([chr(((ord(ch)" " + 5 - ord(\"a\")) % 26) + ord(\"a\")) for ch in s])\n\n\n"] for helper in helper_funcs_to_be_included: if helper in prompt: helper_func = helper check_program = ( (prompt[:import_packages_end_index] if import_packages_end_index != -1 else "") + helper_func + completion + "\n" + f'result = {test}' ) try: exec_globals = {} with swallow_io(): with time_limit(timeout): exec(check_program, exec_globals) result.append(exec_globals['result']) except TimeoutException: result.append("FAILURE: timed out") except BaseException as e: error_type = type(e).__name__ result.append(f"FAILURE: {error_type}: {e}") # Needed for cleaning up. shutil.rmtree = rmtree os.rmdir = rmdir os.chdir = chdir manager = multiprocessing.Manager() result = manager.list() p = multiprocessing.Process(target=unsafe_execute) p.start() p.join(timeout=timeout+1) if p.is_alive(): p.kill() if not result: result.append("timed out") return result[0] @contextlib.contextmanager def time_limit(seconds: float): def signal_handler(signum, frame): raise TimeoutException("Timed out!") signal.setitimer(signal.ITIMER_REAL, seconds) signal.signal(signal.SIGALRM, signal_handler) try: yield finally: signal.setitimer(signal.ITIMER_REAL, 0) @contextlib.contextmanager def swallow_io(): stream = WriteOnlyStringIO() with contextlib.redirect_stdout(stream): with contextlib.redirect_stderr(stream): with redirect_stdin(stream): yield @contextlib.contextmanager def create_tempdir(): with tempfile.TemporaryDirectory() as dirname: with chdir(dirname): yield dirname class TimeoutException(Exception): pass class WriteOnlyStringIO(io.StringIO): """ StringIO that throws an exception when it's read from """ def read(self, *args, **kwargs): raise IOError def readline(self, *args, **kwargs): raise IOError def readlines(self, *args, **kwargs): raise IOError def readable(self, *args, **kwargs): """ Returns True if the IO object can be read. """ return False class redirect_stdin(contextlib._RedirectStream): # type: ignore _stream = 'stdin' @contextlib.contextmanager def chdir(root): if root == ".": yield return cwd = os.getcwd() os.chdir(root) try: yield except BaseException as exc: raise exc finally: os.chdir(cwd) def reliability_guard(maximum_memory_bytes: Optional[int] = None): """ This disables various destructive functions and prevents the generated code from interfering with the test (e.g. fork bomb, killing other processes, removing filesystem files, etc.) WARNING This function is NOT a security sandbox. Untrusted code, including, model- generated code, should not be blindly executed outside of one. See the Codex paper for more information about OpenAI's code sandbox, and proceed with caution. """ if maximum_memory_bytes is not None: import resource resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) if not platform.uname().system == 'Darwin': resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) faulthandler.disable() import builtins builtins.exit = None builtins.quit = None import os os.environ['OMP_NUM_THREADS'] = '1' os.kill = None os.system = None os.putenv = None os.remove = None os.removedirs = None os.rmdir = None os.fchdir = None os.setuid = None os.fork = None os.forkpty = None os.killpg = None os.rename = None os.renames = None os.truncate = None os.replace = None os.unlink = None os.fchmod = None os.fchown = None os.chmod = None os.chown = None os.chroot = None os.fchdir = None os.lchflags = None os.lchmod = None os.lchown = None os.getcwd = None os.chdir = None import shutil shutil.rmtree = None shutil.move = None shutil.chown = None import subprocess subprocess.Popen = None # type: ignore __builtins__['help'] = None import sys sys.modules['ipdb'] = None sys.modules['joblib'] = None sys.modules['resource'] = None sys.modules['psutil'] = None sys.modules['tkinter'] = None