# Source path: Factool/factool/code/helper/_execution.py
# (Hugging Face page header) EQ3A2A's picture
# (Hugging Face commit message) Upload folder using huggingface_hub
# (Hugging Face commit id) d195d4f
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional, Dict
import contextlib
import faulthandler
import io
import os
import multiprocessing
import platform
import signal
import tempfile
def test_case_against_solution(prompt: str, completion: str, test: str, timeout: float):
"""
Get the output of a testcases against a solution
"""
def unsafe_execute():
with create_tempdir():
# These system calls are needed when cleaning up tempdir.
import os
import shutil
rmtree = shutil.rmtree
rmdir = os.rmdir
chdir = os.chdir
# Disable functionalities that can make destructive changes to the test.
reliability_guard()
# Construct the check program and run it. VERY IMPORTANT!!!
import_packages_end_index = prompt.find("def")
# some HumanEval prompts contain helper function that has
# to be included. Patching.
helper_func = ""
helper_funcs_to_be_included = \
["\n\ndef is_palindrome(string: str) -> bool:\n"
" \"\"\" Test if given string is a palindrome \"\"\"\n"
" return string == string[::-1]\n\n\n",
"import math\n\n\ndef poly(xs: list, x: float):\n"
" \"\"\"\n Evaluates polynomial with"
" coefficients xs at point x.\n"
" return xs[0] + xs[1] * x + xs[1] * x^2 + .... xs[n] * x^n\n"
" \"\"\"\n return sum([coeff * math.pow(x, i) for i,"
" coeff in enumerate(xs)])\n\n\n", "\n\ndef encode_cyclic(s: str):\n"
" \"\"\"\n returns encoded"
" string by cycling groups of three"
" characters.\n \"\"\"\n"
" # split string to groups."
" Each of length 3.\n"
" groups"
" = [s[(3 * i):min((3 * i + 3),"
" len(s))] for"
" i in range((len(s) + 2) // 3)]\n"
" # cycle elements in each"
" group. Unless group has fewer"
" elements than 3.\n "
"groups = [(group[1:] + group[0])"
" if len(group) == 3 "
"else group for group in groups]\n"
" "
"return \"\".join(groups)\n\n\n",
"\n\ndef encode_shift(s: str):\n \"\"\"\n"
" returns encoded string by shifting every"
" character by 5 in the alphabet.\n \"\"\"\n"
" return \"\".join([chr(((ord(ch)"
" + 5 - ord(\"a\")) % 26) + ord(\"a\")) for ch in s])\n\n\n"]
for helper in helper_funcs_to_be_included:
if helper in prompt:
helper_func = helper
check_program = (
(prompt[:import_packages_end_index] if import_packages_end_index != -1 else "") + helper_func + completion + "\n" + f'result = {test}'
)
try:
exec_globals = {}
with swallow_io():
with time_limit(timeout):
exec(check_program, exec_globals)
result.append(exec_globals['result'])
except TimeoutException:
result.append("FAILURE: timed out")
except BaseException as e:
error_type = type(e).__name__
result.append(f"FAILURE: {error_type}: {e}")
# Needed for cleaning up.
shutil.rmtree = rmtree
os.rmdir = rmdir
os.chdir = chdir
manager = multiprocessing.Manager()
result = manager.list()
p = multiprocessing.Process(target=unsafe_execute)
p.start()
p.join(timeout=timeout+1)
if p.is_alive():
p.kill()
if not result:
result.append("timed out")
return result[0]
@contextlib.contextmanager
def time_limit(seconds: float):
    """Raise ``TimeoutException`` in the body if it runs longer than ``seconds``.

    Uses ``SIGALRM`` with the real-time interval timer. On exit the timer is
    cleared and the SIGALRM handler that was active before is put back.

    Args:
        seconds: Time budget passed to ``signal.setitimer``.

    Raises:
        TimeoutException: inside the ``with`` body when the timer fires.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # Install the handler BEFORE arming the timer: with a very small budget
    # the alarm could otherwise fire while the default disposition (process
    # termination) is still in place.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    try:
        signal.setitimer(signal.ITIMER_REAL, seconds)
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
        # signal.signal() returns None when the old handler was not installed
        # from Python; that value cannot be passed back in.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
@contextlib.contextmanager
def swallow_io():
    """Route stdout, stderr and stdin to one write-only buffer for the body."""
    sink = WriteOnlyStringIO()
    # Entered left to right and left in reverse, same as nested ``with``s.
    with contextlib.redirect_stdout(sink), \
            contextlib.redirect_stderr(sink), \
            redirect_stdin(sink):
        yield
@contextlib.contextmanager
def create_tempdir():
    """Create a temporary directory, make it the cwd for the body, yield its path."""
    with tempfile.TemporaryDirectory() as scratch_dir, chdir(scratch_dir):
        yield scratch_dir
class TimeoutException(Exception):
    """Signals that the time budget of a guarded block was exceeded."""
class WriteOnlyStringIO(io.StringIO):
    """In-memory text stream that accepts writes but refuses every read."""

    def _refuse_read(self, *args, **kwargs):
        # IOError is the same class as OSError.
        raise OSError

    # All three reading entry points share the refusing implementation.
    read = readline = readlines = _refuse_read

    def readable(self, *args, **kwargs):
        """Report that this stream cannot be read: always False."""
        return False
class redirect_stdin(contextlib._RedirectStream):  # type: ignore
    """Context manager that temporarily replaces ``sys.stdin``."""

    # Name of the ``sys`` attribute that the base class swaps in and out.
    _stream = "stdin"
@contextlib.contextmanager
def chdir(root):
    """Run the body with ``root`` as working directory, then switch back.

    Passing ``"."`` is a no-op: the working directory is neither read nor
    changed.
    """
    if root != ".":
        previous_dir = os.getcwd()
        os.chdir(root)
        try:
            yield
        finally:
            # Exceptions from the body propagate unchanged once the original
            # directory is restored. ``os.chdir`` is looked up here, at call
            # time, on purpose.
            os.chdir(previous_dir)
    else:
        yield
def reliability_guard(maximum_memory_bytes: Optional[int] = None):
    """
    This disables various destructive functions and prevents the generated code
    from interfering with the test (e.g. fork bomb, killing other processes,
    removing filesystem files, etc.)

    The changes are process-wide and are not undone by this function, so it
    should only be called in a process that is dedicated to running the
    generated code.

    Args:
        maximum_memory_bytes: When given, used as both soft and hard limit
            for RLIMIT_AS and RLIMIT_DATA (and RLIMIT_STACK, except on
            Darwin).

    WARNING
    This function is NOT a security sandbox. Untrusted code, including, model-
    generated code, should not be blindly executed outside of one. See the
    Codex paper for more information about OpenAI's code sandbox, and proceed
    with caution.
    """
    if maximum_memory_bytes is not None:
        import resource
        limits = (maximum_memory_bytes, maximum_memory_bytes)
        resource.setrlimit(resource.RLIMIT_AS, limits)
        resource.setrlimit(resource.RLIMIT_DATA, limits)
        if platform.uname().system != 'Darwin':
            resource.setrlimit(resource.RLIMIT_STACK, limits)

    faulthandler.disable()

    import builtins
    builtins.exit = None
    builtins.quit = None

    import os
    os.environ['OMP_NUM_THREADS'] = '1'

    # Attributes are assigned unconditionally, so names that do not exist on
    # the current platform are simply created with the value None.
    for name in (
        'kill', 'system', 'putenv', 'remove', 'removedirs', 'rmdir',
        'fchdir', 'setuid', 'fork', 'forkpty', 'killpg', 'rename',
        'renames', 'truncate', 'replace', 'unlink', 'fchmod', 'fchown',
        'chmod', 'chown', 'chroot', 'lchflags', 'lchmod', 'lchown',
        'getcwd', 'chdir',
    ):
        setattr(os, name, None)

    import shutil
    for name in ('rmtree', 'move', 'chown'):
        setattr(shutil, name, None)

    import subprocess
    subprocess.Popen = None  # type: ignore

    # Go through the ``builtins`` module rather than ``__builtins__``: the
    # latter is a dict in imported modules but a module in ``__main__``,
    # where item assignment raises TypeError.
    builtins.help = None

    import sys
    # A None entry in sys.modules makes a later ``import <name>`` raise
    # ImportError.
    for name in ('ipdb', 'joblib', 'resource', 'psutil', 'tkinter'):
        sys.modules[name] = None