|
import re |
|
|
|
# Per-language extraction settings, keyed by the short language code.
#   "full_name": fence tag; extract_generation_code lowercases it to find
#                the matching ```<tag> fenced block in the model output.
#   "indent":    number of spaces assumed before the closing "}" when the
#                prompt's function signature cannot be located in the block.
#   "main":      optional marker; when present in the extracted block,
#                everything from the marker onward is dropped.
languge_settings = {
    "python": {"full_name": "Python", "indent": 4},
    "cpp": {"full_name": "cpp", "indent": 0, "main": "int main()"},
    "java": {"full_name": "Java", "indent": 4, "main": "public static void main"},
    "cs": {"full_name": "csharp", "indent": 0, "main": "public static void Main"},
    "php": {"full_name": "PHP", "indent": 0},
    "ts": {"full_name": "TypeScript", "indent": 0},
    "js": {"full_name": "JavaScript", "indent": 0},
    "sh": {"full_name": "Bash", "indent": 0},
}
|
|
|
|
|
def get_function_name(question: str, lang: str):
    """Split a prompt into the target function's signature head and its prefix.

    Blank lines are removed from ``question`` before anything else.

    * Python (``lang`` compared case-insensitively): the target is the LAST
      line that starts with ``"def "`` at column 0. The returned name is that
      line up to the first ``"("`` -- it still contains the ``def`` keyword,
      e.g. ``"def foo"``.
    * Any other language: the target is the last non-blank line, cut at the
      first ``"{"``.

    Args:
        question: The prompt text containing the function signature.
        lang: Language name; only ``"python"`` (any casing) is special-cased.

    Returns:
        A ``(func_name, func_prefix)`` tuple, where ``func_prefix`` is every
        non-blank line before the target line, joined with newlines.

    Raises:
        IndexError: If the prompt has no non-blank lines, or (Python only) no
            line starting with ``"def "``. The exception type matches what the
            original list indexing raised; only the message is more specific.
    """
    func_lines = [line for line in question.strip().split("\n") if line.strip()]
    if not func_lines:
        raise IndexError("prompt contains no non-blank lines")

    if lang.lower() == "python":
        def_indices = [
            idx for idx, line in enumerate(func_lines) if line.startswith("def ")
        ]
        if not def_indices:
            raise IndexError("no line starting with 'def ' found in the prompt")
        # The last top-level def is the function the prompt asks to complete;
        # earlier defs are helpers and belong to the prefix.
        func_idx = def_indices[-1]
        func_name = func_lines[func_idx].split("(")[0].strip()
        func_prefix = "\n".join(func_lines[:func_idx])
        return func_name, func_prefix

    func_name = func_lines[-1].split("{")[0].strip()
    func_prefix = "\n".join(func_lines[:-1])
    return func_name, func_prefix
|
|
|
|
|
def _find_closing_brace(code_block: str, indent: int) -> int:
    """Return the index of the last newline + ``indent`` spaces + ``}``.

    Falls back to ``len(code_block)`` when no such closing brace exists, so
    the caller's slice keeps the whole remainder of the block.
    """
    closing = "\n" + " " * indent + "}"
    position = code_block.rfind(closing)
    return position if position != -1 else len(code_block)


def extract_generation_code(example: dict, lang_code: str, verbose: bool = False):
    """Extract the function body from a model response and store it in ``example``.

    Steps:
      1. Take the first fenced code block whose tag equals the language's
         lowercased ``full_name`` (matched case-insensitively).
      2. If the language defines a ``main`` marker and it occurs in the
         block, drop everything from the marker onward.
      3. Locate the prompt's function signature (see ``get_function_name``)
         in the block, case-insensitively, and measure how many spaces
         precede it. If it is not found, start at 0 and use the language's
         configured ``indent`` instead.
      4. Cut the block at the last closing brace at that indentation (or
         keep it to the end when there is none).
      5. For ``php``/``ts``/``js`` re-append that closing brace.
      6. Store ``func_prefix + "\\n" + body + "\\n"`` in
         ``example["generation"]``.

    If any step raises, the error is printed and ``example["generation"]``
    falls back to the raw prompt followed by the raw output.

    Args:
        example: Mapping with ``"task_id"``, ``"prompt"`` and either
            ``"output"`` or ``"gpt_completion"``. It is mutated in place.
        lang_code: Key into ``languge_settings`` (e.g. ``"python"``, ``"cpp"``).
        verbose: When true, print the extracted code block.

    Returns:
        The same ``example`` object, with ``"generation"`` set.

    Raises:
        KeyError: If ``task_id``/``prompt`` is missing or ``lang_code`` is unknown.
        TypeError: If the example carries neither ``output`` nor
            ``gpt_completion`` (the fallback concatenation then fails on None).
    """
    task_id = example["task_id"]
    output = example.get("output", example.get("gpt_completion"))
    question = example["prompt"].strip()
    setting = languge_settings[lang_code]
    lang = setting["full_name"]
    indent = setting["indent"]

    try:
        # Escape the tag so a language name containing regex metacharacters
        # can never change the meaning of the pattern.
        fence_pattern = "```" + re.escape(lang.lower()) + "\n(.*?)```"
        code_block: str = re.findall(
            fence_pattern, output, re.DOTALL | re.IGNORECASE
        )[0]
        if verbose:
            print(">>> Task: {}\n{}".format(task_id, code_block))

        # Drop the driver/main section if the model emitted one.
        main_marker = setting.get("main", None)
        if main_marker and main_marker in code_block:
            code_block = code_block[: code_block.index(main_marker)]

        func_name, func_prefix = get_function_name(question, lang)

        start = code_block.lower().find(func_name.lower())
        if start == -1:
            # Signature not present: keep the block from its beginning and
            # rely on the language's configured indent for the brace search.
            start = 0
        else:
            # Count the spaces immediately before the signature. The bound
            # is checked on the index actually read, so the scan can never
            # wrap around to code_block[-1].
            indent = 0
            while start - indent - 1 >= 0 and code_block[start - indent - 1] == " ":
                indent += 1

        end = _find_closing_brace(code_block, indent)
        body = code_block[start:end]

        # These languages get the function's closing brace put back.
        if lang_code.lower() in ("php", "ts", "js"):
            body += "\n" + " " * indent + "}"

        example["generation"] = func_prefix + "\n" + body + "\n"

    except Exception as ex:
        # Best-effort fallback: report the failure and keep the raw output.
        print(
            "Failed to extract code block with error `{}`:\n>>> Task: {}\n>>> Output:\n{}".format(
                ex, task_id, output
            )
        )
        example["generation"] = example["prompt"] + "\n" + output

    return example
|
|
|
|
|
def cleanup_code(
    code: str,
    language_type: str = None,
    dataset: str = None,
    issft: bool = False,
    stop_words=None,
):
    """Clean up generated code by truncating it at the first stop word.

    Args:
        code: The generated code.
        language_type: Language identifier, compared case-insensitively.
            ``"python"`` and ``"ts"`` get special handling; anything else,
            including ``None``, is truncated with ``stop_words`` only.
        dataset: Unused; kept so existing callers keep working.
        issft: For Python only, first reduce ``code`` to the contents of its
            first ```` ```python ```` fenced block.
        stop_words: Iterable of strings at which to cut the code. Defaults to
            no stop words. Ignored for Python, which always uses its own
            fixed list.

    Returns:
        The code up to (not including) the earliest stop word found.
    """
    # Copy into a fresh list: avoids a shared mutable default and accepts
    # any iterable (tuple, set, ...) from callers.
    base_stop_words = list(stop_words) if stop_words is not None else []
    language = (language_type or "").lower()

    if language == "python":
        if issft:
            code = _clean_python_code_for_sft(code)
        # Python deliberately overrides caller-supplied stop words.
        python_stop_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"]
        return _truncate_code_at_stopwords(code, python_stop_words)

    if language == "ts":
        base_stop_words += [
            "\nexport",
            "\nimport",
            "\nexport default",
            "\nimport default",
            "\nconsole.log",
        ]

    return _truncate_code_at_stopwords(code, base_stop_words)
|
|
|
|
|
def _clean_python_code_for_sft(code): |
|
code = code.replace("\r", "") |
|
if "```python" in code: |
|
code_start_idx = code.index("```python") |
|
code = code[code_start_idx:].replace("```python", "").strip() |
|
end_idx = code.find("```") if "```" in code else len(code) |
|
code = code[:end_idx].strip() |
|
|
|
return code |
|
|
|
|
|
def _truncate_code_at_stopwords(code, stop_words): |
|
min_stop_idx = len(code) |
|
for stop_word in stop_words: |
|
stop_index = code.find(stop_word) |
|
if 0 <= stop_index < min_stop_idx: |
|
min_stop_idx = stop_index |
|
return code[:min_stop_idx] |
|
|