import gradio as gr
from gradio_client import Client
import frontmatter
import os
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

import huggingface_hub

import prep_decompiled

hf_key = os.environ["HF_TOKEN"]
huggingface_hub.login(token=hf_key)

tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase-3b")
vardecoder_model = AutoModelForCausalLM.from_pretrained(
    "ejschwartz/resym-vardecoder",
    torch_dtype=torch.bfloat16,
    # device_map={"": 0},
).to("cuda")
fielddecoder_model = AutoModelForCausalLM.from_pretrained(
    "ejschwartz/resym-fielddecoder",
    torch_dtype=torch.bfloat16,
    # device_map={"": 0},
).to("cuda")

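# Companion Space that extracts the memory-access expressions (e.g. a1[1])
# appearing in the decompiled code; its output drives the field-decoder prompt.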
gradio_client = Client("https://ejschwartz-resym-field-helper.hf.space/")

with open("examples.txt", "r") as f:
    examples = [ex.encode().decode("unicode_escape") for ex in f.readlines()]


# Example prompt
#   "input": "```\n_BOOL8 __fastcall sub_409B9A(_QWORD *a1, _QWORD *a2)\n{\nreturn *a1 < *a2 || *a1 == *a2 && a1[1] < a2[1];\n}\n```\nWhat are the variable name and type for the following memory accesses:a1, a1[1], a2, a2[1]?\n",
#  "output": "a1: a, os_reltime* -> sec, os_time_t\na1[1]: a, os_reltime* -> usec, os_time_t\na2: b, os_reltime* -> sec, os_time_t\na2[1]: b, os_reltime* -> usec, os_time_t",
def field_prompt(code):
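    """Build the field-decoder prompt for `code` by asking the field-helper
    Space which memory-access expressions it contains."""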
    field_helper_result = gradio_client.predict(
        decompiled_code=code,
        api_name="/predict",
    )
    print(f"field helper result: {field_helper_result}")

    fields = sorted([e["expr"] for e in field_helper_result[0] if e["expr"] != ""])
    print(f"fields: {fields}")

    prompt = f"```\n{code}\n```\nWhat are the variable name and type for the following memory accesses:{', '.join(fields)}?\n"

    print(f"field prompt: {prompt}")

    return prompt, field_helper_result

@spaces.GPU
def infer(code):
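    """Run the ReSym variable decoder on Hex-Rays decompiled code, returning
    the raw model prediction and the list of variables queried."""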

    # Normalize the decompilation: strip leading/trailing whitespace from each line.
    splitcode = [s.strip() for s in code.splitlines()]
    code = "\n".join(splitcode)
    # Variable names come from two places: declaration comments in the body,
    # and the argument list in the function signature.
    bodyvars = [
        v["name"] for v in prep_decompiled.extract_comments(splitcode) if "name" in v
    ]
    argvars = [
        v["name"] for v in prep_decompiled.parse_signature(splitcode) if "name" in v
    ]
    varnames = argvars + bodyvars
    if not varnames:
        raise gr.Error("No variables were found in the decompiled code.")

    varstring = ", ".join([f"`{v}`" for v in varnames])

    var_name = varnames[0]

    # ejs: Appending the first variable name after the code block looks bizarre,
    # but it matches how the ReSym inference script builds its prompts; see
    # https://github.com/lt-asset/resym/blob/main/training_src/fielddecoder_inf.py
    var_prompt = f"What are the original name and data types of variables {varstring}?\n```\n{code}\n```{var_name}"

    print(f"Prompt:\n{var_prompt}")

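    # Truncate the prompt so that, together with max_new_tokens=1024, the
    # sequence fits in StarCoder's 8192-token context window.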
    input_ids = tokenizer.encode(var_prompt, return_tensors="pt").cuda()[
        :, : 8192 - 1024
    ]
    var_output = vardecoder_model.generate(
        input_ids=input_ids,
        max_new_tokens=1024,
        num_beams=4,
        num_return_sequences=1,
        do_sample=False,
        early_stopping=False,
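        # StarCoder uses token id 0 (<|endoftext|>) as both pad and EOS.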
        pad_token_id=0,
        eos_token_id=0,
    )[0]
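    # Decode only the newly generated tokens, skipping the echoed prompt.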
    var_output = tokenizer.decode(
        var_output[input_ids.size(1) :],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    field_prompt_result, field_helper_result = field_prompt(code)

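    # The field-decoder pass below is disabled for now; the prompt is still
    # built so it can be inspected in the logs.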
    # field_output = fielddecoder_model.generate(
    #     input_ids=input_ids,
    #     max_new_tokens=1024,
    #     num_beams=4,
    #     num_return_sequences=1,
    #     do_sample=False,
    #     early_stopping=False,
    #     pad_token_id=0,
    #     eos_token_id=0,
    # )[0]
    # field_output = tokenizer.decode(
    #     field_output[input_ids.size(1) :],
    #     skip_special_tokens=True,
    #     clean_up_tokenization_spaces=True,
    # )

    var_output = var_name + ":" + var_output
    # field_output = var_name + ":" + field_output
    return var_output, varstring


demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(lines=10, value=examples[0], label="Hex-Rays Decompilation"),
    ],
    outputs=[
        gr.Text(label="Var Decoder Output"),
        # gr.Text(label="Field Decoder Output"),
        gr.Text(label="Generated Variable List"),
    ],
    description=frontmatter.load("README.md").content,
    examples=examples,
)
demo.launch()