qianxiao1111's picture
upgrade: add benchmarks eval
2a26d3b
raw
history blame
4.58 kB
import re
# Per-language extraction settings, keyed by the short language code.
#   full_name: language name; its lowercase form is the Markdown fence tag
#              searched for in the model output.
#   indent:    number of spaces assumed before the closing "}" when the
#              function signature cannot be located in the code block.
#   main:      (optional) text marking the start of an entry point; the code
#              block is cut off at its first occurrence.
languge_settings = {
    "python": {"full_name": "Python", "indent": 4},
    "cpp": {"full_name": "cpp", "indent": 0, "main": "int main()"},
    "java": {"full_name": "Java", "indent": 4, "main": "public static void main"},
    "cs": {"full_name": "csharp", "indent": 0, "main": "public static void Main"},
    "php": {"full_name": "PHP", "indent": 0},
    "ts": {"full_name": "TypeScript", "indent": 0},
    "js": {"full_name": "JavaScript", "indent": 0},
    "sh": {"full_name": "Bash", "indent": 0},
}
def get_function_name(question: str, lang: str):
    """Split a prompt into the target function's signature head and the text before it.

    Blank lines in ``question`` are dropped first.

    * Python (``lang`` compared case-insensitively): the signature line is the
      last line that starts with ``"def "`` at column 0; the returned name is
      everything on that line before the first ``"("`` (so it still includes
      the ``def`` keyword).
    * Any other language: the signature line is the last non-blank line; the
      returned name is everything on it before the first ``"{"``.

    Returns:
        ``(func_name, func_prefix)`` where ``func_prefix`` is the non-blank
        lines preceding the signature line, joined with newlines.

    Raises:
        IndexError: if the prompt has no non-blank line, or (Python) no line
            starting with ``"def "``.
    """
    lines = [text for text in question.strip().split("\n") if text.strip()]

    if lang.lower() == "python":
        def_rows = [row for row, text in enumerate(lines) if text.startswith("def ")]
        cut, opener = def_rows[-1], "("
    else:
        # -1 selects the last line and, as a slice bound, everything before it.
        cut, opener = -1, "{"

    signature = lines[cut].partition(opener)[0].strip()
    preamble = "\n".join(lines[:cut])
    return signature, preamble
def _closing_brace_index(code_block: str, indent: int) -> int:
    """Return the index of the last newline + ``indent`` spaces + ``}``, or ``len(code_block)`` if absent."""
    pos = code_block.rfind("\n" + " " * indent + "}")
    return len(code_block) if pos == -1 else pos


def extract_generation_code(example: dict, lang_code: str, verbose: bool = False):
    """Pull the target function out of a model completion and store it in ``example["generation"]``.

    The completion is read from ``example["output"]`` (falling back to
    ``example["gpt_completion"]``). The first fenced code block tagged with the
    language's full name is taken, anything from the language's ``main`` marker
    onwards is dropped, and the text from the prompt's function signature up to
    the last matching closing brace is kept. The result is the prompt's prefix,
    a newline, that body, and a trailing newline.

    Args:
        example: mapping with at least ``"task_id"`` and ``"prompt"``, plus the
            completion under ``"output"`` or ``"gpt_completion"``. It is
            modified in place.
        lang_code: key into ``languge_settings`` (e.g. ``"python"``, ``"cpp"``).
        verbose: when true, print the raw extracted code block.

    Returns:
        The same ``example`` object with ``"generation"`` set. If extraction
        fails for any reason, the failure is printed and ``"generation"`` falls
        back to the prompt followed by the raw completion.

    Raises:
        KeyError: if ``task_id``/``prompt`` are missing or ``lang_code`` is unknown.
    """
    task_id = example["task_id"]
    output = example.get("output", example.get("gpt_completion"))
    question = example["prompt"].strip()
    setting = languge_settings[lang_code]
    lang = setting["full_name"]
    indent = setting["indent"]

    try:
        # Escape the tag so a language name can never be read as regex syntax.
        code_block: str = re.findall(
            f"```{re.escape(lang.lower())}\n(.*?)```",
            output,
            re.DOTALL | re.IGNORECASE,
        )[0]
        if verbose:
            print(">>> Task: {}\n{}".format(task_id, code_block))

        # Remove main
        main_marker = setting.get("main")
        if main_marker and main_marker in code_block:
            code_block = code_block[: code_block.index(main_marker)]

        func_name, func_prefix = get_function_name(question, lang)

        start = code_block.lower().find(func_name.lower())
        if start == -1:
            # Signature not found: keep the block from its beginning and rely
            # on the language's configured indent for the closing brace.
            start = 0
        else:
            # Count the spaces immediately before the signature. The bound is
            # strictly positive so index -1 (the last character of the block)
            # is never read by mistake.
            indent = 0
            while start - indent > 0 and code_block[start - indent - 1] == " ":
                indent += 1
        end = _closing_brace_index(code_block, indent)

        body = code_block[start:end]
        # For these languages the closing brace cut off above is put back.
        if lang_code.lower() in ["php", "ts", "js"]:
            body += "\n" + " " * indent + "}"

        generation = func_prefix + "\n" + body + "\n"
        example["generation"] = generation
    except Exception as ex:
        # Best-effort boundary: report the problem and fall back to the raw text.
        print(
            "Failed to extract code block with error `{}`:\n>>> Task: {}\n>>> Output:\n{}".format(
                ex, task_id, output
            )
        )
        # A missing completion must not make the fallback itself raise.
        example["generation"] = example["prompt"] + "\n" + (output or "")
    return example
def cleanup_code(
    code: str,
    language_type: str = None,
    dataset: str = None,
    issft: bool = False,
    stop_words=None,
):
    """
    Cleans up the generated code by truncating it at the first stop word.

    Args:
        code: the generated text to clean.
        language_type: language identifier, compared case-insensitively.
            ``"python"`` and ``"ts"`` get special handling; anything else,
            including ``None``, only uses the caller's ``stop_words``.
        dataset: accepted for interface compatibility; not used here.
        issft: for Python only, first strip the Markdown code fence from the
            text before truncating.
        stop_words: extra stop strings supplied by the caller. Ignored for
            Python, which always uses its own fixed list. Defaults to none.

    Returns:
        The text up to (not including) the earliest stop word found, or the
        whole text when none occurs.
    """
    # Normalise once; a missing language falls through to the generic branch
    # instead of failing on ``None.lower()``.
    lang = (language_type or "").lower()
    # Fresh list per call: avoids a shared mutable default and accepts any iterable.
    caller_stops = [] if stop_words is None else list(stop_words)

    if lang == "python":
        if issft:
            code = _clean_python_code_for_sft(code)
        python_stops = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"]
        return _truncate_code_at_stopwords(code, python_stops)

    if lang == "ts":
        ts_stops = [
            "\nexport",
            "\nimport",
            "\nexport default",
            "\nimport default",
            "\nconsole.log",
        ]
        return _truncate_code_at_stopwords(code, caller_stops + ts_stops)

    return _truncate_code_at_stopwords(code, caller_stops)
def _clean_python_code_for_sft(code):
code = code.replace("\r", "")
if "```python" in code:
code_start_idx = code.index("```python")
code = code[code_start_idx:].replace("```python", "").strip()
end_idx = code.find("```") if "```" in code else len(code)
code = code[:end_idx].strip()
return code
def _truncate_code_at_stopwords(code, stop_words):
min_stop_idx = len(code)
for stop_word in stop_words:
stop_index = code.find(stop_word)
if 0 <= stop_index < min_stop_idx:
min_stop_idx = stop_index
return code[:min_stop_idx]