# Hugging Face file-view header (page residue, not part of the program):
# author: mamiksik | commit: "Add t5predictor to app.py" (83b862c) | size: 2.72 kB
import re
import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, RobertaTokenizer
# Tokenizer and model are loaded from the same checkpoint, pinned to one
# revision so both always come from the same snapshot.
_CHECKPOINT = "mamiksik/CommitPredictorT5PL"
_REVISION = "fb08d01"

tokenizer = RobertaTokenizer.from_pretrained(_CHECKPOINT, revision=_REVISION)
model = T5ForConditionalGeneration.from_pretrained(_CHECKPOINT, revision=_REVISION)
def parse_files(accumulator: list[str], patch: str) -> list[str]:
    """Convert a ``git diff`` patch into tagged lines and append them.

    Every emitted content line is prefixed with ``<add>`` (was ``+``),
    ``<del>`` (was ``-``) or ``<ide>`` (anything else). A ``---``/``+++``
    header pair becomes a single ``<ide><path>...`` line when both names
    match, or an ``<add><path>...`` / ``<del><path>...`` pair otherwise.
    Lines starting with ``index`` or ``diff`` are dropped, and ``@@ ... @@``
    hunk markers are removed from the line they appear on.

    Args:
        accumulator: List that receives the converted lines (mutated in place).
        patch: Patch text to convert.

    Returns:
        The same ``accumulator`` list that was passed in.
    """
    filename_before = None
    for line in patch.splitlines():
        if line.startswith(("index", "diff")):
            continue

        # Only treat "--- <path>" / "+++ <path>" (with the separating space)
        # as file headers. A bare "----" or "++++" is a removed/added content
        # line (e.g. a markdown rule) and used to raise IndexError here.
        if line.startswith("--- "):
            # [1:] drops the first character of the path, e.g. the side
            # prefix in git's usual "a/foo.py", leaving "/foo.py".
            filename_before = line.split(" ", 1)[1][1:]
            continue

        if line.startswith("+++ "):
            filename_after = line.split(" ", 1)[1][1:]
            if filename_before == filename_after:
                accumulator.append(f"<ide><path>{filename_before}")
            else:
                accumulator.append(f"<add><path>{filename_after}")
                accumulator.append(f"<del><path>{filename_before}")
            continue

        # Strip the "@@ -a,b +c,d @@" marker; any trailing context stays.
        line = re.sub("@@[^@@]*@@", "", line)
        if not line:
            continue

        if line[0] == "+":
            line = line.replace("+", "<add>", 1)
        elif line[0] == "-":
            line = line.replace("-", "<del>", 1)
        else:
            line = f"<ide>{line}"

        accumulator.append(line)

    return accumulator
def predict(patch, max_length, min_length, num_beams, prediction_count):
    """Generate commit-message candidates for a ``git diff`` patch.

    Args:
        patch: Patch text; converted with ``parse_files`` before tokenization.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        min_length: Forwarded to ``model.generate`` as ``min_length``.
        num_beams: Forwarded to ``model.generate`` as ``num_beams``.
        prediction_count: Requested number of returned sequences; capped at
            ``num_beams``.

    Returns:
        A 3-tuple of: the token count of the parsed patch tokenized without
        truncation, the parsed patch text, and a dict mapping each decoded
        prediction to ``0``.
    """
    # UI sliders can deliver floats, while generation lengths/counts must be
    # integers.
    max_length = int(max_length)
    min_length = int(min_length)
    num_beams = int(num_beams)
    # generate() rejects num_return_sequences greater than num_beams, so cap
    # the request instead of failing.
    prediction_count = min(int(prediction_count), num_beams)

    input_text = "\n".join(parse_files([], patch))

    with torch.no_grad():
        # Counted separately, without truncation, so the reported size
        # reflects the whole parsed patch.
        token_count = tokenizer(input_text, return_tensors="pt").input_ids.shape[1]

        input_ids = tokenizer(
            input_text,
            truncation=True,
            padding=True,
            return_tensors="pt",
        ).input_ids

        outputs = model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            num_return_sequences=prediction_count,
        )

    result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return token_count, input_text, {k: 0 for k in result}
# Gradio UI: one patch textbox plus generation controls, wired to predict().
# Every slider uses step=1 because the values are token lengths and
# beam/sequence counts, which must be whole numbers.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Patch (as generated by git diff)"),
        gr.Slider(1, 128, value=20, step=1, label="Max message length"),
        gr.Slider(1, 128, value=5, step=1, label="Min message length"),
        gr.Slider(1, 10, value=7, step=1, label="Number of beams"),
        gr.Slider(1, 15, value=5, step=1, label="Number of predictions"),
    ],
    outputs=[
        gr.Textbox(label="Token count"),
        gr.Textbox(label="Parsed patch"),
        gr.Label(label="Predictions"),
    ],
)

if __name__ == "__main__":
    iface.launch()