# HumanEval evaluation script for code-interpreter models (Llama / GPT).
"""HumanEval evaluation harness for Llama / GPT based code interpreters."""
import os
import sys
import traceback
import re

# Make the sibling ``human-eval`` checkout importable: three directory levels
# up from this file, then into the ``human-eval`` directory. This must run
# before the ``human_eval`` import below.
HUMAN_EVAL_PATH = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
    "human-eval",
)
sys.path.append(HUMAN_EVAL_PATH)

from human_eval.data import write_jsonl, read_problems
from finetuning.conversation_template import msg_to_code_result_tok_temp
from code_interpreter.llama_hf import build_model_from_hf_path
from code_interpreter.LlamaCodeInterpreter import LlamaCodeInterpreter
from code_interpreter.GPTCodeInterpreter import GPTCodeInterpreter
from code_interpreter.RetrospectiveGPTCodeInterpreter import (
    RetrospectiveGPTCodeInterpreter,
)

from rich import print
from rich.panel import Panel
from rich.syntax import Syntax
from rich.text import Text
from timeout_decorator import timeout

wrong = 0  # NOTE(review): never read anywhere in this file — candidate for removal
def extract_text(prompt, remove_lines=True):
    """Return the natural-language problem text from a HumanEval prompt.

    Extracts the text between the opening triple-quote docstring marker and
    the first doctest marker (``>>>``). When *remove_lines* is true, all
    whitespace runs are collapsed to single spaces and the result is stripped.
    """
    opening = '"""'
    closing = ">>>"
    begin = prompt.find(opening) + len(opening)
    finish = prompt.find(closing)
    extracted = prompt[begin:finish]
    if remove_lines:
        # Collapse newlines and any whitespace runs into single spaces.
        extracted = re.sub(r"\s+", " ", extracted.replace("\n", " ")).strip()
    return extracted
def extract_all_code_block(input_str: str) -> "str | None":
    """Concatenate every ``[CODE_START_TOK]...[/CODE_END_TOK]`` block.

    Returns the stripped block contents joined by newlines, or ``None`` when
    *input_str* contains no such block. (Fixed the return annotation: the
    original claimed ``-> str`` but returns ``None`` on the no-match path,
    which callers rely on via ``is not None`` checks.)
    """
    pattern = r"\[CODE_START_TOK\](.*?)\[/CODE_END_TOK\]"
    matches = re.findall(pattern, input_str, re.DOTALL)
    return "\n".join(match.strip() for match in matches) if matches else None
def extract_all_code_block_gpt(input_str: str) -> "str | None":
    """Concatenate every fenced ```` ```python ... ``` ```` block.

    Returns the stripped block contents joined by newlines, or ``None`` when
    *input_str* contains no fenced Python block. (Fixed the return
    annotation: the original claimed ``-> str`` but returns ``None`` on the
    no-match path, which callers rely on via ``is not None`` checks.)
    """
    pattern = r"```python(.*?)```"
    matches = re.findall(pattern, input_str, re.DOTALL)
    return "\n".join(match.strip() for match in matches) if matches else None
def delete_print_asser(code_text: str):
    """Return *code_text* with every ``print(...)`` line removed.

    A line is dropped when its stripped form starts with ``print(``.
    NOTE(review): despite the name, ``assert`` lines are NOT removed.
    """
    kept = [
        line
        for line in code_text.split("\n")
        if not line.strip().startswith("print(")
    ]
    return "\n".join(kept)
def extract_function_from_code_block(code_block: str) -> str:
    """Extract the first function definition (and its body) from *code_block*.

    Collection starts at the first line beginning with ``def ``; it stops at
    the first subsequent line that is unindented and neither a comment nor
    another ``def``. That terminator line, plus any trailing blank or comment
    lines, is trimmed before returning.
    """
    collected = []
    collecting = False
    for raw in code_block.split("\n"):
        if raw.startswith("def "):
            collecting = True
        if not collecting:
            continue
        collected.append(raw)
        # An unindented, non-comment, non-def line marks the end of the body.
        if not (
            raw.startswith(" ") or raw.startswith("#") or raw.startswith("def ")
        ):
            break
    # Trim the terminator line plus any trailing blanks or comments.
    while collected and (
        collected[-1].strip() == ""
        or collected[-1].strip().startswith("#")
        or not collected[-1].startswith(" ")
    ):
        collected.pop()
    return "\n".join(collected)
def get_last_outermost_function_name(function_str):
    """Return the name of the last column-0 ``def`` in *function_str*.

    Returns ``""`` when no top-level function definition is present.
    """
    found = re.findall(r"^def (\w+)", function_str, re.MULTILINE)
    return found[-1] if found else ""
def get_last_function_name(function_str):
    """Return the name of the last ``def`` (any indentation) in *function_str*.

    Returns ``""`` when no function definition is present.
    """
    found = re.findall(r"def (\w+)", function_str)
    return found[-1] if found else ""
def get_outermost_function_name(function_str):
    """Return the name of the first column-0 ``def`` in *function_str*.

    Returns ``""`` when no top-level function definition is present.
    """
    found = re.findall(r"^def (\w+)", function_str, re.MULTILINE)
    return found[0] if found else ""
def get_function_name(function_str):
    """Return the name of the first ``def`` in *function_str* ('' if none).

    Bug fix: the original returned ``match.group(0)`` — the whole matched
    text ``"def name"`` — unlike every sibling helper in this file
    (``get_last_function_name`` etc.), which returns just the captured name.
    """
    match = re.search(r"def (\w+)", function_str)
    return match.group(1) if match else ""
def extract_test_assertion(test_func: str):
    """Collect every line of *test_func* containing ``assert``.

    Each matching line is stripped of surrounding whitespace; the matches are
    joined with newlines.
    """
    assertions = [
        line.strip() for line in test_func.split("\n") if "assert" in line
    ]
    return "\n".join(assertions).strip()
# Prelude prepended to every candidate solution before it is exec'd (see
# exec_with_timeout), so generated code may rely on these common imports.
import_str = """
import re
import math
from typing import List, Tuple, Optional
"""
def exec_with_timeout(import_str, full_test_code):
    """Execute *import_str* followed by *full_test_code* in a scratch namespace.

    Returns ``True`` when the code runs to completion, ``False`` when any
    exception is raised (the error type/message is printed).

    NOTE(review): despite the name, no timeout is enforced here — the
    ``timeout_decorator`` import at the top of the file is never applied, so
    the ``TimeoutError`` handler in ``__main__`` cannot fire; confirm whether
    a ``@timeout(...)`` decorator was intended.

    Security: this exec()s model-generated code — run only in a sandboxed
    environment.
    """
    # Fresh, empty namespace. The previous ``env = {**locals()}`` leaked this
    # function's own parameter names (``import_str``, ``full_test_code``) into
    # the exec'd code, where they could shadow names in the tested solution.
    env = {}
    code_to_exec = f"{import_str}\n{full_test_code}"
    try:
        exec(code_to_exec, env)
    except Exception as e:
        print(f"Error Type: {type(e).__name__}, Error Message: {e}")
        return False  # Return False if there's an error during execution
    return True  # Return True if executed without errors
if __name__ == "__main__":
    # Also make the project root importable for the code_interpreter imports.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    import argparse

    parser = argparse.ArgumentParser(description="Process path for LLAMA2_FINETUNEED.")
    parser.add_argument(
        "--path",
        type=str,
        required=True,  # NOTE(review): required=True makes this default dead code
        help="Path to the finetuned LLAMA2 model.",
        default='"./output/llama-2-7b-chat-ci"',
    )
    parser.add_argument(
        "--model",
        type=str,
        required=False,
        # Bug fix: help text was copy-pasted from --path. This flag selects the
        # model; a value containing "gpt" switches to the GPT interpreter.
        help="Model name (a value containing 'gpt' selects the GPT interpreter).",
        default='"./output/llama-2-7b-chat-ci"',
    )
    parser.add_argument(
        "--max-retry",
        type=int,
        required=False,
        help="Maximum number of retries.",
        default=5,  # You can set any default value you want here.
    )
    args = parser.parse_args()

    # True: send the raw programming puzzle; False: send the extracted
    # natural-language description only.
    PROGRAMMING_PUZZLE_Q = True

    problems = read_problems()
    correct_total = 0
    total_problems = len(problems)

    for idx, task_id in enumerate(problems):
        # Build a fresh interpreter per task so dialog state cannot leak
        # between problems.
        if "gpt" not in args.model.lower():
            LLAMA2_FINETUNEED_PATH = args.path
            interpreter = LlamaCodeInterpreter(
                model_path=LLAMA2_FINETUNEED_PATH,
                # load_in_4bit=True
            )
        else:
            interpreter = RetrospectiveGPTCodeInterpreter(
                model=args.model,
            )

        # problems[task_id] keys:
        # dict_keys(['task_id', 'prompt', 'entry_point', 'canonical_solution', 'test'])
        # NOTE(review): the extraction-mangled source read replace(" ", "\t"),
        # which would tab-separate every word; 4-space indent -> tab is the
        # plausible original intent — confirm against the upstream repo.
        programming_puzzle = problems[task_id]["prompt"].replace("    ", "\t")
        text_only_problem = extract_text(programming_puzzle)

        interpreter.dialog = [
            {
                "role": "system",
                "content": "You are helpful robot that can generate code , excute it and debug then answer",
            }
        ]

        if PROGRAMMING_PUZZLE_Q:
            # programming puzzle
            output_str = interpreter.chat(
                user_message=f"Write a Python script to solve the following problem:\n{programming_puzzle}\nEnsure the solution is verified by printing the expected output.",
                MAX_TRY=args.max_retry,
                VERBOSE=True,
                code_exec_prefix=f"\nfrom typing import List,Tuple\nimport math\n",
                feedback_prompt="Ensure the output matches the expected result, taking into account any corner cases. If discrepancies arise, pinpoint where you went wrong. Then, refine the code to achieve the desired outcome.",
                append_result=True,
            )["content"]
        else:
            output_str = interpreter.chat(
                user_message=f"Write a Python script for this problem:\n{text_only_problem}",
                MAX_TRY=args.max_retry,
                VERBOSE=True,
                code_exec_prefix=f"\nfrom typing import List,Tuple\nimport math\n",
                feedback_prompt="Ensure the output matches the expected result. If not tell where you got wrong, then refine the code to achieve the desired outcome.",
                append_result=True,
            )["content"]

        # Pull the generated code out of the transcript; tokenized blocks for
        # the Llama models, fenced ```python blocks for GPT.
        function_str = ""
        if "gpt" not in args.model.lower():
            code_block = extract_all_code_block(output_str)
        else:
            code_block = extract_all_code_block_gpt(output_str)
        if (code_block is not None) and ("def" in code_block):
            function_str = code_block

        # function_name = get_last_outermost_function_name(function_str)
        function_str = delete_print_asser(function_str)
        function_name = get_last_outermost_function_name(function_str)
        full_test_code = f"{function_str}\n#-----------\n{problems[task_id]['test']}\ncheck({function_name})"

        # Print the full_test_code with syntax highlighting
        syntax = Syntax(
            # f"{programming_puzzle}\n{full_test_code}",
            f"{full_test_code}",
            "python",
            theme="monokai",
            line_numbers=True,
        )
        print(syntax)

        is_correct = False  # default is wrong
        timeout_flag = False
        try:
            is_correct = exec_with_timeout(import_str, full_test_code)
        except TimeoutError as e:
            timeout_flag = True
            print(f"Timeout with error msg : {e}")

        if is_correct:
            correct_total += 1
        # Bug fix: accuracy must be recomputed every iteration; guarding it
        # behind ``if is_correct`` would leave ``acc`` undefined (NameError in
        # the summary below) whenever the first problem fails.
        acc = correct_total / (idx + 1)

        # save dialog
        interpreter.save_dialog(
            path=f"./eval/gpt_humaneval_output/{task_id.replace('/','_')}_{is_correct}.json"
        )
        interpreter.close()
        del interpreter

        # Constructing the output
        accuracy_text = Text(
            f"Accuracy: {correct_total}/{idx+1}[{total_problems}] = {acc:.2%} [{is_correct}]",
            style="bold blue",
        )
        panel = Panel(accuracy_text, title="Results", border_style="green")
        print(panel)