| import re | |
| import gradio as gr | |
| import torch | |
| from transformers import T5ForConditionalGeneration, RobertaTokenizer | |
# Load the commit-message model and its tokenizer from the Hugging Face Hub.
# The revision is pinned so the app keeps working even if the Hub repo changes.
tokenizer = RobertaTokenizer.from_pretrained("mamiksik/CommitPredictorT5PL", revision="fb08d01")
model = T5ForConditionalGeneration.from_pretrained("mamiksik/CommitPredictorT5PL", revision="fb08d01")
def parse_files(accumulator: list[str], patch: str):
    """Translate a raw ``git diff`` patch into the model's tagged line format.

    Appends one tagged line per relevant patch line to *accumulator* (which
    is mutated in place and also returned): file headers become
    ``<ide>/<add>/<del><path>...`` markers, added lines get an ``<add>``
    prefix, removed lines ``<del>``, and unchanged context lines ``<ide>``.
    """
    old_path = None
    for raw in patch.splitlines():
        # Skip git's per-file preamble ("diff --git ..." / "index ...").
        if raw.startswith(("index", "diff")):
            continue
        if raw.startswith("---"):
            old_path = raw.split(" ", 1)[1][1:]
            continue
        if raw.startswith("+++"):
            new_path = raw.split(" ", 1)[1][1:]
            if old_path == new_path:
                accumulator.append(f"<ide><path>{old_path}")
            else:
                # Renamed (or created/deleted) file: record both paths.
                accumulator.append(f"<add><path>{new_path}")
                accumulator.append(f"<del><path>{old_path}")
            continue
        # Strip hunk headers such as "@@ -1,3 +1,3 @@"; a pure header
        # collapses to the empty string and is dropped entirely.
        content = re.sub("@@[^@@]*@@", "", raw)
        if not content:
            continue
        sign = content[0]
        if sign == "+":
            accumulator.append(content.replace("+", "<add>", 1))
        elif sign == "-":
            accumulator.append(content.replace("-", "<del>", 1))
        else:
            accumulator.append(f"<ide>{content}")
    return accumulator
def predict(patch, max_length, min_length, num_beams, prediction_count):
    """Generate commit-message candidates for a git patch.

    Args:
        patch: raw ``git diff`` output pasted by the user.
        max_length / min_length: bounds (in tokens) for generated messages.
        num_beams: beam-search width.
        prediction_count: number of candidate messages to return.

    Returns:
        A tuple of (full token count of the parsed patch, the parsed/tagged
        patch text, and a label dict of candidate messages).
    """
    accumulator = []
    parse_files(accumulator, patch)
    input_text = '\n'.join(accumulator)

    with torch.no_grad():
        # Untruncated token count, surfaced so the user can tell whether
        # the patch exceeded the model's context window.
        token_count = tokenizer(input_text, return_tensors="pt").input_ids.shape[1]
        input_ids = tokenizer(
            input_text,
            truncation=True,
            padding=True,
            return_tensors="pt",
        ).input_ids

        outputs = model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            num_return_sequences=prediction_count,
        )

    result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # gr.Label expects {label: confidence}; generate() yields no scores here,
    # so each candidate gets a placeholder confidence of 0. Reuse input_text
    # instead of re-joining the accumulator a second time.
    return token_count, input_text, {k: 0 for k in result}
# Gradio UI wiring: the input widgets map positionally onto predict()'s
# parameters, and the outputs onto its three return values.
input_widgets = [
    gr.Textbox(label="Patch (as generated by git diff)"),
    gr.Slider(1, 128, value=20, label="Max message length"),
    gr.Slider(1, 128, value=5, label="Min message length"),
    gr.Slider(1, 10, value=7, label="Number of beams"),
    gr.Slider(1, 15, value=5, label="Number of predictions"),
]
output_widgets = [
    gr.Textbox(label="Token count"),
    gr.Textbox(label="Parsed patch"),
    gr.Label(label="Predictions"),
]
iface = gr.Interface(fn=predict, inputs=input_widgets, outputs=output_widgets)

if __name__ == "__main__":
    iface.launch()