Spaces:
Runtime error
Runtime error
| from itertools import chain | |
| import pandas as pd | |
| from tqdm import tqdm | |
| import config | |
| import dataset_statistics | |
| from api_wrappers import grazie_wrapper, hf_data_loader | |
| from generation_steps import examples | |
# Number of synthetic start messages generated for every input row.
GENERATION_MULTIPLIER = 3
# A generated candidate is accepted immediately when its "insertions"
# statistic is strictly below this value.
REL_INSERTIONS_THRESHOLD = 0.5
# Maximum number of model calls made while looking for an acceptable candidate.
GENERATION_ATTEMPTS = 3
def build_prompt(reference, diff):
    """Build the prompt asking the model to reconstruct an initial commit message.

    The prompt presents the code changes and the edited (final) commit
    message, then asks for the message as it looked before the edits. It
    embeds ``examples.EXAMPLES_END_TO_START`` as few-shot examples and ends
    with the ``OUTPUT`` token after which the model is expected to answer.

    Args:
        reference: The edited commit message to show to the model.
        diff: The source code changes the commit message describes.

    Returns:
        The complete prompt as a single string.
    """
    return f"""A software developer uses a LLM to generate commit messages.
They generated a commit message for the following source code changes:
START OF THE SOURCE CODE CHANGES
{diff}
END OF THE SOURCE CODE CHANGES
After generating the commit message the developer understands that it is not perfect. After making some changes,
they come up with an edited version of the message. Here is this edited message:
START OF THE COMMIT MESSAGE
{reference}
END OF THE COMMIT MESSAGE
Your task is to print the initial, LLM-generated commit message.
The message you print must share some fragments with the edited message.
Here are some examples of what you should output:
START OF THE EXAMPLES LIST
{examples.EXAMPLES_END_TO_START}
END OF THE EXAMPLES LIST
Print only the initial commit message's text after the
token "OUTPUT".
OUTPUT"""
def generate_start_msg(end_msg, diff):
    """Generate a synthetic start message for a given edited message.

    Queries the model up to ``GENERATION_ATTEMPTS`` times with the same
    prompt. The first candidate whose ``"insertions"`` statistic (as reported
    by ``dataset_statistics.get_statistics_for_sample``) is below
    ``REL_INSERTIONS_THRESHOLD`` is returned straight away. If no candidate
    passes, the one with the lowest ``"insertions"`` value is returned; among
    equally scored candidates the earliest generated one wins.

    Args:
        end_msg: The edited (final) commit message.
        diff: The source code changes described by the message.

    Returns:
        The selected generated start message.

    Raises:
        ValueError: If ``GENERATION_ATTEMPTS`` is less than 1, so no
            candidate was ever generated.
    """
    prompt = build_prompt(reference=end_msg, diff=diff)
    rejected = []

    for _ in range(GENERATION_ATTEMPTS):
        start_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
        stats = dataset_statistics.get_statistics_for_sample(
            start_msg=start_msg_pred,
            end_msg=end_msg,
        )
        # NOTE(review): the threshold is named "relative" but is compared with
        # stats["insertions"]; confirm that this statistic is normalized.
        if stats["insertions"] < REL_INSERTIONS_THRESHOLD:
            return start_msg_pred
        rejected.append((stats["insertions"], start_msg_pred))

    # Select by score only: sorting the raw tuples would compare the message
    # texts whenever two scores tie.
    return min(rejected, key=lambda candidate: candidate[0])[1]
# Columns copied unchanged from the source row into every generated row.
COLS_TO_KEEP = ["hash", "repo", "commit_msg_end", "mods", "session"]
# Columns that generated rows receive with a fixed default value
# (column name -> default).
COLS_TO_DEFAULT = {"edit_time": None}
def transform(df):
    """Extend a dataset with synthetic "end -> start" rows.

    For every row of ``df``, ``GENERATION_MULTIPLIER`` start messages are
    generated from the row's ``commit_msg_end`` and ``mods`` values. Each
    generated row keeps the ``COLS_TO_KEEP`` values of its source row, takes
    the fixed values from ``COLS_TO_DEFAULT`` and is flagged with
    ``end_to_start = True``; the original rows are flagged ``False``.

    The combined frame is written to ``config.END_TO_START_ARTIFACT`` as CSV
    and the run configuration is printed to stdout. The input frame itself
    is not modified.

    Args:
        df: DataFrame containing at least the ``COLS_TO_KEEP`` columns.

    Returns:
        A new DataFrame with the original rows followed by the generated
        ones, re-indexed from zero.
    """
    print("End -> start synthesis:")
    print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
    print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
    print(f"REL_INSERTIONS_THRESHOLD = {REL_INSERTIONS_THRESHOLD}")
    print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")

    # Flag the original rows on a copy so the caller's DataFrame stays untouched.
    df = df.assign(end_to_start=False)

    # Column-oriented buffer; the key order defines the generated frame's column order.
    generated_data = {"commit_msg_start": []}
    for col in chain(COLS_TO_KEEP, COLS_TO_DEFAULT):
        generated_data[col] = []

    for _index, row in tqdm(df.iterrows(), total=len(df)):
        end_msg = row["commit_msg_end"]
        for _ in range(GENERATION_MULTIPLIER):
            commit_msg_start_pred = generate_start_msg(end_msg=end_msg, diff=row["mods"])
            generated_data["commit_msg_start"].append(commit_msg_start_pred)
            for col in COLS_TO_KEEP:
                generated_data[col].append(row[col])
            for col, default in COLS_TO_DEFAULT.items():
                generated_data[col].append(default)

    generated_df = pd.DataFrame.from_dict(generated_data)
    generated_df["end_to_start"] = True

    result = pd.concat([df, generated_df], ignore_index=True)
    result.to_csv(config.END_TO_START_ARTIFACT)
    print("Done")
    return result
def main():
    """Load the processed rewriting dataset and run the end -> start synthesis on it."""
    transform(hf_data_loader.load_processed_rewriting_as_pandas())
# Run the synthesis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()