# File size: 4,581 Bytes
# 2a26d3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import re

# Per-language extraction settings, keyed by short language code.
#   full_name: language tag expected right after the opening ``` fence
#              (extract_generation_code matches it case-insensitively)
#   indent:    default number of spaces assumed before the closing "}"
#   main:      optional marker of the program entry point; everything from
#              this marker onward is cut from the extracted code block
languge_settings = {
    "python": {"full_name": "Python", "indent": 4},
    "cpp": {"full_name": "cpp", "indent": 0, "main": "int main()"},
    "java": {"full_name": "Java", "indent": 4, "main": "public static void main"},
    "cs": {"full_name": "csharp", "indent": 0, "main": "public static void Main"},
    "php": {"full_name": "PHP", "indent": 0},
    "ts": {"full_name": "TypeScript", "indent": 0},
    "js": {"full_name": "JavaScript", "indent": 0},
    "sh": {"full_name": "Bash", "indent": 0},
}


def get_function_name(question: str, lang: str):
    """Split a prompt into its target signature and the text preceding it.

    Blank lines are dropped first. For Python (``lang`` compared
    case-insensitively) the target is the last line that starts with
    ``"def "``; the returned name is that line up to the first ``"("`` and
    therefore still carries the ``def`` keyword. For every other language
    the target is the last non-blank line, cut at the first ``"{"``.

    Returns:
        A ``(name, prefix)`` tuple where ``prefix`` is the non-blank lines
        before the target line, joined with newlines.

    Raises:
        IndexError: if the prompt has no non-blank line, or (Python) no line
            starting with ``"def "``.
    """
    lines = [text for text in question.strip().split("\n") if text.strip()]

    if lang.lower() != "python":
        signature_line = lines[-1]
        preceding = lines[:-1]
        return signature_line.split("{")[0].strip(), "\n".join(preceding)

    def_positions = [pos for pos, text in enumerate(lines) if text.startswith("def ")]
    last_def = def_positions[-1]
    name = lines[last_def].split("(")[0].strip()
    return name, "\n".join(lines[:last_def])


def extract_generation_code(example: dict, lang_code: str, verbose: bool = False):
    """Extract the generated function from a fenced model reply.

    The reply is read from ``example["output"]`` (falling back to
    ``example["gpt_completion"]``). The first fenced block tagged with the
    language's ``full_name`` is taken, anything from the language's ``main``
    marker onward is dropped, and the text from the prompt's function
    signature up to the last ``"}"`` at the signature's indentation is kept.
    The result, prefixed by the prompt lines that precede the signature, is
    stored in ``example["generation"]``.

    If any extraction step fails, the failure is printed and
    ``example["generation"]`` falls back to the prompt followed by the raw
    reply.

    Args:
        example: Record with ``task_id``, ``prompt`` and either ``output`` or
            ``gpt_completion``. It is modified in place.
        lang_code: Key into ``languge_settings``.
        verbose: When true, print the extracted code block.

    Returns:
        The same ``example`` object, with ``generation`` set.

    Raises:
        KeyError: if ``task_id`` or ``prompt`` is missing from ``example`` or
            ``lang_code`` is not a key of ``languge_settings``.
    """
    task_id = example["task_id"]
    output = example.get("output", example.get("gpt_completion"))
    question = example["prompt"].strip()
    setting = languge_settings[lang_code]
    lang = setting["full_name"]
    indent = setting["indent"]

    try:
        code_block: str = re.findall(
            f"```{lang.lower()}\n(.*?)```", output, re.DOTALL | re.IGNORECASE
        )[0]
        if verbose:
            print(">>> Task: {}\n{}".format(task_id, code_block))

        # Remove main
        main_marker = setting.get("main")
        if main_marker and main_marker in code_block:
            code_block = code_block[: code_block.index(main_marker)]

        func_name, func_prefix = get_function_name(question, lang)

        try:
            start = code_block.lower().index(func_name.lower())
        except ValueError:
            # Signature not present in the reply: keep the block from its
            # beginning and rely on the language's default indent.
            start = 0
        else:
            # Count the spaces directly before the signature. The bound is
            # strict so the scan stops at the start of the block instead of
            # wrapping around to code_block[-1].
            indent = 0
            while start - indent > 0 and code_block[start - indent - 1] == " ":
                indent += 1

        # Cut at the last closing brace sitting at the signature's indent;
        # without one (e.g. Python) keep everything up to the end.
        end = code_block.rfind("\n" + " " * indent + "}")
        if end == -1:
            end = len(code_block)

        body = code_block[start:end]

        # For these languages the closing brace removed by the cut above is
        # appended back to the body.
        if lang_code.lower() in ["php", "ts", "js"]:
            body += "\n" + " " * indent + "}"

        generation = func_prefix + "\n" + body + "\n"
        example["generation"] = generation

    except Exception as ex:
        print(
            "Failed to extract code block with error `{}`:\n>>> Task: {}\n>>> Output:\n{}".format(
                ex, task_id, output
            )
        )
        # A record without any reply (output is None) must not make the
        # fallback itself raise; use an empty reply in that case.
        raw_output = output if isinstance(output, str) else ""
        example["generation"] = example["prompt"] + "\n" + raw_output

    return example


def cleanup_code(
    code: str,
    language_type: str = None,
    dataset: str = None,
    issft: bool = False,
    stop_words=None,
):
    """
    Cleans up the generated code by cutting it at the earliest stop word.

    Args:
        code: Generated text to clean.
        language_type: Language name, compared case-insensitively. ``None``
            is treated like any language without special handling.
        dataset: Accepted for interface compatibility; not used here.
        issft: For Python only, first reduce ``code`` to the contents of its
            ```` ```python ```` fence via ``_clean_python_code_for_sft``.
        stop_words: Extra stop words. Used for every language except Python,
            which always uses its own fixed list. ``None`` means no extras.

    Returns:
        ``code`` truncated just before the first stop word found, or
        unchanged when none occurs.
    """
    language = (language_type or "").lower()
    # A fresh list per call avoids sharing one mutable default between calls.
    extra_stop_words = [] if stop_words is None else list(stop_words)

    if language == "python":
        if issft:
            code = _clean_python_code_for_sft(code)
        # Caller-supplied stop words are deliberately ignored for Python.
        return _truncate_code_at_stopwords(
            code, ["\ndef", "\nclass", "\nif", "\n#", "\nprint"]
        )

    if language == "ts":
        return _truncate_code_at_stopwords(
            code,
            extra_stop_words
            + [
                "\nexport",
                "\nimport",
                "\nexport default",
                "\nimport default",
                "\nconsole.log",
            ],
        )

    return _truncate_code_at_stopwords(code, extra_stop_words)


def _clean_python_code_for_sft(code):
    code = code.replace("\r", "")
    if "```python" in code:
        code_start_idx = code.index("```python")
        code = code[code_start_idx:].replace("```python", "").strip()
        end_idx = code.find("```") if "```" in code else len(code)
        code = code[:end_idx].strip()

    return code


def _truncate_code_at_stopwords(code, stop_words):
    min_stop_idx = len(code)
    for stop_word in stop_words:
        stop_index = code.find(stop_word)
        if 0 <= stop_index < min_stop_idx:
            min_stop_idx = stop_index
    return code[:min_stop_idx]