|
import re |
|
|
|
import gradio as gr |
|
import torch |
|
from transformers import T5ForConditionalGeneration, RobertaTokenizer |
|
|
|
|
|
# Seq2seq commit-message model: a T5 conditional-generation model paired with
# a RoBERTa tokenizer (as published in the Hub repo). The revision is pinned
# so the demo stays reproducible if the repo is updated. Both calls download
# weights on first run — this happens at import time.
tokenizer = RobertaTokenizer.from_pretrained("mamiksik/CommitPredictorT5PL", revision="fb08d01")

model = T5ForConditionalGeneration.from_pretrained("mamiksik/CommitPredictorT5PL", revision="fb08d01")
|
|
|
|
|
def parse_files(accumulator: list[str], patch: str):
    """Translate a raw ``git diff`` patch into the model's token format.

    Appends one entry per meaningful patch line onto *accumulator*:
    ``<add>``/``<del>`` tag inserted and removed lines, ``<ide>`` tags
    unchanged context, and ``<path>`` entries record the file touched
    (``<ide><path>`` when unchanged, an ``<add>``/``<del>`` pair on rename).

    Returns the same *accumulator* list for convenience.
    """
    old_path = None

    for raw in patch.splitlines():
        # Header noise ("diff --git ...", "index ...") carries no signal.
        if raw.startswith(("index", "diff")):
            continue

        if raw.startswith("---"):
            # "--- a/some/file": remember the pre-image path.
            # NOTE(review): [1:] drops only the "a", keeping a leading "/" —
            # presumably this matches the model's training format.
            old_path = raw.split(" ", 1)[1][1:]
            continue

        if raw.startswith("+++"):
            new_path = raw.split(" ", 1)[1][1:]
            if old_path == new_path:
                accumulator.append(f"<ide><path>{old_path}")
            else:
                # Paths differ (rename/create): emit both sides.
                accumulator.append(f"<add><path>{new_path}")
                accumulator.append(f"<del><path>{old_path}")
            continue

        # Strip hunk headers like "@@ -1,4 +1,4 @@"; drop lines that are
        # empty once the header is removed.
        cleaned = re.sub("@@[^@@]*@@", "", raw)
        if not cleaned:
            continue

        if cleaned.startswith("+"):
            accumulator.append("<add>" + cleaned[1:])
        elif cleaned.startswith("-"):
            accumulator.append("<del>" + cleaned[1:])
        else:
            accumulator.append(f"<ide>{cleaned}")

    return accumulator
|
|
|
|
|
def predict(patch, max_length, min_length, num_beams, prediction_count):
    """Generate commit-message candidates for a git patch.

    Returns a triple matching the Gradio outputs:
    (untruncated token count, parsed patch text, predictions dict).
    The dict maps each decoded message to 0 — gr.Label only needs the
    keys to display candidates.
    """
    parsed_lines = parse_files([], patch)
    model_input = "\n".join(parsed_lines)

    with torch.no_grad():
        # Tokenize once WITHOUT truncation so the UI reports the patch's
        # true token length, even when it exceeds the model's limit.
        token_count = tokenizer(model_input, return_tensors="pt").input_ids.shape[1]

        encoded = tokenizer(
            model_input,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )

        beam_outputs = model.generate(
            encoded.input_ids,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            num_return_sequences=prediction_count,
        )

    messages = tokenizer.batch_decode(beam_outputs, skip_special_tokens=True)
    return token_count, model_input, {message: 0 for message in messages}
|
|
|
|
|
# Gradio UI wiring: one textbox for the raw `git diff` output plus sliders
# for the generation knobs, in the same order as predict()'s parameters.
# Outputs mirror predict()'s (token_count, parsed_patch, predictions) triple.
iface = gr.Interface(fn=predict, inputs=[
    gr.Textbox(label="Patch (as generated by git diff)"),
    gr.Slider(1, 128, value=20, label="Max message length"),
    gr.Slider(1, 128, value=5, label="Min message length"),
    gr.Slider(1, 10, value=7, label="Number of beams"),
    gr.Slider(1, 15, value=5, label="Number of predictions"),
], outputs=[
    gr.Textbox(label="Token count"),
    gr.Textbox(label="Parsed patch"),
    gr.Label(label="Predictions")
])
|
|
|
if __name__ == "__main__":
    # Start the Gradio demo server when run as a script.
    iface.launch()
|
|