Spaces:
Runtime error
Runtime error
from itertools import chain | |
import pandas as pd | |
from tqdm import tqdm | |
import config | |
import dataset_statistics | |
from api_wrappers import grazie_wrapper, hf_data_loader | |
from generation_steps import examples | |
# How many synthetic start messages are produced for every input row.
GENERATION_MULTIPLIER = 3
# Upper bound on LLM calls spent looking for a single acceptable start message.
GENERATION_ATTEMPTS = 3
# A generated start message is accepted as soon as its "insertions" statistic
# (presumably a relative ratio, judging by the name -- see dataset_statistics)
# falls below this value.
REL_INSERTIONS_THRESHOLD = 0.5
def build_prompt(reference, diff):
    """Build the LLM prompt that asks for the initial message behind an edited one.

    The prompt does three things:

    * shows the source code changes (``diff``) and the developer-edited commit message
      (``reference``);
    * embeds the few-shot examples from ``examples.EXAMPLES_END_TO_START``;
    * asks the model to print the message it originally generated, after an ``OUTPUT`` token.

    Args:
        reference: the edited (final) commit message.
        diff: the source code changes the commit message describes.

    Returns:
        The complete prompt text.
    """
    return f"""A software developer uses an LLM to generate commit messages.
They generated a commit message for the following source code changes:
START OF THE SOURCE CODE CHANGES
{diff}
END OF THE SOURCE CODE CHANGES
After generating the commit message the developer understands that it is not perfect. After making some changes,
they come up with an edited version of the message. Here is this edited message:
START OF THE COMMIT MESSAGE
{reference}
END OF THE COMMIT MESSAGE
Your task is to print the initial, LLM-generated commit message.
The message you print must share some fragments with the edited message.
Here are some examples of what you should output:
START OF THE EXAMPLES LIST
{examples.EXAMPLES_END_TO_START}
END OF THE EXAMPLES LIST
Print only the initial commit message's text after the
token "OUTPUT".
OUTPUT"""
def generate_start_msg(end_msg, diff):
    """Generate a synthetic initial ("start") commit message for ``end_msg``.

    Up to ``GENERATION_ATTEMPTS`` candidates are requested from the LLM via
    ``grazie_wrapper.generate_for_prompt``. For each candidate, the ``"insertions"`` value from
    ``dataset_statistics.get_statistics_for_sample(start_msg=..., end_msg=end_msg)`` is computed.

    * The first candidate whose value is below ``REL_INSERTIONS_THRESHOLD`` is returned
      immediately.
    * Otherwise, the candidate with the smallest value is returned. Ties are broken by
      comparing the message text.

    Args:
        end_msg: the edited (final) commit message.
        diff: the source code changes the message describes.

    Returns:
        The chosen generated start message.
    """
    prompt = build_prompt(reference=end_msg, diff=diff)
    scored_candidates = []
    for _ in range(GENERATION_ATTEMPTS):
        candidate = grazie_wrapper.generate_for_prompt(prompt)
        insertions = dataset_statistics.get_statistics_for_sample(
            start_msg=candidate,
            end_msg=end_msg,
        )["insertions"]
        if insertions < REL_INSERTIONS_THRESHOLD:
            # Good enough: stop spending LLM calls.
            return candidate
        scored_candidates.append((insertions, candidate))
    # No candidate met the threshold; fall back to the least-edited one.
    best_score, best_candidate = sorted(scored_candidates)[0]
    return best_candidate
# Columns copied verbatim from each source row into the synthetic rows derived from it.
COLS_TO_KEEP = [
    "hash",
    "repo",
    "commit_msg_end",
    "mods",
    "session",
]
# Columns that synthetic rows fill with a fixed default value (column name -> default).
COLS_TO_DEFAULT = dict(edit_time=None)
def _print_settings():
    """Print the end -> start synthesis hyper-parameters to stdout."""
    print("End -> start synthesis:")
    print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
    print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
    print(f"REL_INSERTIONS_THRESHOLD = {REL_INSERTIONS_THRESHOLD}")
    print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")


def _generate_rows(df):
    """Build column-oriented data (column name -> list of values) for the synthetic rows of ``df``."""
    generated_data = {"commit_msg_start": []}
    for col in chain(COLS_TO_KEEP, COLS_TO_DEFAULT):
        generated_data[col] = []

    for _, row in tqdm(df.iterrows(), total=len(df)):
        # These values are the same for every synthetic copy of this row.
        end_msg = row["commit_msg_end"]
        diff = row["mods"]
        kept_values = [row[col] for col in COLS_TO_KEEP]
        for _attempt in range(GENERATION_MULTIPLIER):
            generated_data["commit_msg_start"].append(generate_start_msg(end_msg=end_msg, diff=diff))
            for col, value in zip(COLS_TO_KEEP, kept_values):
                generated_data[col].append(value)
            for col, default in COLS_TO_DEFAULT.items():
                generated_data[col].append(default)
    return generated_data


def transform(df):
    """Append LLM-synthesized end -> start rows to ``df``, save the result as CSV and return it.

    For every input row, ``GENERATION_MULTIPLIER`` synthetic rows are produced. Each one:

    * stores the message returned by ``generate_start_msg`` in ``commit_msg_start``;
    * copies the ``COLS_TO_KEEP`` columns from the source row;
    * sets the ``COLS_TO_DEFAULT`` columns to their defaults.

    Input rows get ``end_to_start=False`` and synthetic rows ``end_to_start=True``.
    Columns present only in ``df`` are left missing (NaN) in synthetic rows by ``pd.concat``.

    Args:
        df: input frame containing at least the ``COLS_TO_KEEP`` columns. It is not modified;
            earlier versions added the ``end_to_start`` column to it in place.

    Returns:
        A new DataFrame holding the input rows followed by the synthetic rows, with a fresh
        0..n-1 index. The same frame is written to ``config.END_TO_START_ARTIFACT``.
    """
    _print_settings()

    generated_df = pd.DataFrame.from_dict(_generate_rows(df))
    generated_df["end_to_start"] = True

    # assign() works on a copy, so the caller's DataFrame is left untouched.
    result = pd.concat([df.assign(end_to_start=False), generated_df], ignore_index=True)
    result.to_csv(config.END_TO_START_ARTIFACT)
    print("Done")
    return result
def main():
    """Run the end -> start synthesis on the processed rewriting dataset.

    The dataset is loaded via ``hf_data_loader.load_processed_rewriting_as_pandas``.
    """
    transform(hf_data_loader.load_processed_rewriting_as_pandas())


if __name__ == "__main__":
    main()