import os
import random
import sys

import pandas as pd

sys.path.insert(0, "/home/yash/ALIGN-SIM/src")
from utils import mkdir_p, full_path, read_data
from SentencePerturbation.word_replacer import WordReplacer, WordSwapping
from perturbation_args import get_args
def perturb_sentences(
    dataset_name: str,
    task: str,
    target_lang: str = "en",
    output_dir: str = "./data/perturbed_dataset/",
    sample_size: int = 3500,
    save: bool = False,
) -> None:
"""
perturb_sentences _summary_
Args:
dataset_name (str): ["MRPC","PAWS","QQP"]
task (str): ["Synonym","Antonym","Jumbling"]
target_lang (str, optional): _description_. Defaults to "en".
output_dir (str, optional): _description_. Defaults to "./data/perturbed_dataset/".
sample_size (int, optional): _description_. Defaults to 3500.
save (str, optional): _description_. Defaults to False.
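
    Example (illustrative):
        perturb_sentences("MRPC", "Synonym", sample_size=3500, save=True)
        # -> ./data/perturbed_dataset/en/Synonym/MRPC_Synonym_perturbed_en.csv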
"""
print("--------------------------------------")
output_csv = full_path(os.path.join(output_dir, target_lang, task, f"{dataset_name}_{task}_perturbed_{target_lang}.csv"))
if os.path.exists(output_csv):
print(f"File already exists at: {output_csv}")
return
# TODO: make it compatible with other language datasets
print("Loading dataset...")
data = read_data(dataset_name)
if "Unnamed: 0" in data.columns:
data.drop("Unnamed: 0", axis=1, inplace=True)
if "idx" in data.columns:
data.drop("idx", axis=1, inplace=True)
print(f"Loaded {dataset_name} dataset")
print("--------------------------------------")
# Initialize WordReplacer
replacer = WordReplacer()
# set seed
random.seed(42)
    # Create a new dataframe to store the perturbed sentences
    perturbed_data = pd.DataFrame(columns=["original_sentence"])
if task in ["Syn","syn","Synonym"]:
print("Creating Synonym perturbed data...")
sample_data = sampling(data, task, sample_size)
perturbed_data["original_sentence"] = sample_data.sentence1
perturbed_data["perturb_n1"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 1, "synonyms"))
perturbed_data["perturb_n2"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 2, "synonyms"))
perturbed_data["perturb_n3"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 3, "synonyms"))
assert perturbed_data.shape[1] == 4, "Perturbed data size mismatch"
if task in ["paraphrase","Paraphrase","para"]:
print("Creating Paraphrase perturbed data...")
# shuffling the negative samples
# we also want equal number of positive and negative samples
perturbed_data = sampling(data, task, sample_size) # balance data
perturbed_data["original_sentence"] = perturbed_data.sentence1
perturbed_data["paraphrased_sentence"] = perturbed_data.sentence2
assert perturbed_data.shape[1] == 3, "Perturbed data size mismatch" # original_sentence, paraphrased, label
if task in ["Anto","anto","Antonym"]:
print("Creating Antonym perturbed data...")
pos_pairs = sampling(data, task, sample_size)
# Apply antonym replacement
perturbed_data["original_sentence"] = pos_pairs.sentence1
perturbed_data["paraphrased_sentence"] = pos_pairs.sentence2
perturbed_data["perturb_n1"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 1, "antonyms"))
assert perturbed_data.shape[1] == 3, "Perturbed data size mismatch"
# Apply jumbling
if task in ["jumbling", "Jumbling","jumb"]:
print("Creating Jumbling perturbed data...")
pos_pairs = sampling(data, task, sample_size)
perturbed_data["original_sentence"] = pos_pairs.sentence1
perturbed_data["paraphrased_sentence"] = pos_pairs.sentence2
perturbed_data["perturb_n1"]= perturbed_data["original_sentence"].apply(lambda x: WordSwapping.random_swap(x,1))
perturbed_data["perturb_n2"]= perturbed_data["original_sentence"].apply(lambda x: WordSwapping.random_swap(x,2))
perturbed_data["perturb_n3"]= perturbed_data["original_sentence"].apply(lambda x: WordSwapping.random_swap(x,3))
assert perturbed_data.shape[1] == 5, "Perturbed data size mismatch"
# Save to CSV
if save:
perturbed_data.to_csv(mkdir_p(output_csv), index=False)
print("--------------------------------------")
print(f"Saved at: {output_csv}")
print("--------------------------------------")
def sampling(data: pd.DataFrame, task: str, sample_size: int, random_state: int = 42) -> pd.DataFrame:
"""
Combines two sampling strategies:
1. sampled_data: Samples from the dataset by first taking all positive pairs and then,
if needed, filling the remainder with negative pairs.
2. balanced_data: Constructs a dataset with roughly equal positive and negative pairs,
adjusting the numbers if one group is underrepresented.
Returns:
sampled_data (pd.DataFrame): Dataset sampled by filling negatives if positives are insufficient.
positive_data (pd.DataFrame): All positive samples (label == 1).
balanced_data (pd.DataFrame): Dataset balanced between positive and negative pairs.
"""
# Split the data into positive and negative pairs
positive_data = data[data["label"] == 1]
negative_data = data[data["label"] == 0]
if task in ["Anto","anto","Antonym","jumbling", "Jumbling","jumb"]:
return positive_data
    # ----- Sample positive pairs, respecting the requested sample size -----
    if sample_size is None or sample_size > len(positive_data):
        # If no sample size is given, or it exceeds the available positive
        # pairs, use all positive pairs.
        sampled_data = positive_data.copy()
    else:
        # Otherwise, randomly sample the requested number of rows.
        sampled_data = positive_data.sample(n=sample_size, random_state=random_state)
if task in ["Syn","syn","Synonym"]:
return sampled_data
    # ----- Sampling for the Paraphrase criterion -----
    # Shuffle sentence2 across the negative pairs so each negative row pairs
    # sentence1 with a random, unrelated sentence2.
    negative_data = negative_data.reset_index(drop=True)
    shuffled_sentence2 = negative_data["sentence2"].sample(frac=1, random_state=random_state).reset_index(drop=True)
    negative_data["sentence2"] = shuffled_sentence2
    if sample_size is None:
        # Without a target size, take every available pair from each group.
        pos_sample_size = len(positive_data)
        neg_sample_size = len(negative_data)
    else:
        # Determine the ideal sample size per group (half of the total)
        half_size = sample_size // 2
        pos_available = len(positive_data)
        neg_available = len(negative_data)
        pos_sample_size = min(half_size, pos_available)
        neg_sample_size = min(half_size, neg_available)
        # If there is a remainder, top it up from the group with more data
        # left, without exceeding what is actually available.
        total_sampled = pos_sample_size + neg_sample_size
        remainder = sample_size - total_sampled
        if remainder > 0:
            if (pos_available - pos_sample_size) >= (neg_available - neg_sample_size):
                pos_sample_size += min(remainder, pos_available - pos_sample_size)
            else:
                neg_sample_size += min(remainder, neg_available - neg_sample_size)
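        # Worked example (illustrative numbers): sample_size=3500 with 1500
        # positives and 4000 negatives gives half_size=1750, pos=1500,
        # neg=1750, remainder=250; negatives have more rows left, so
        # neg_sample_size becomes 2000.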
# Sample from each group
sampled_positive = positive_data.sample(n=pos_sample_size, random_state=random_state)
sampled_negative = negative_data.sample(n=neg_sample_size, random_state=random_state)
    # Reassert labels: 1 for true paraphrase pairs, 0 for shuffled non-pairs
    sampled_positive["label"] = 1
    sampled_negative["label"] = 0
# Combine and shuffle the resulting dataset
balanced_data = pd.concat([sampled_positive, sampled_negative]).sample(frac=1, random_state=random_state).reset_index(drop=True)
if task in ["paraphrase","Paraphrase","para"]:
return balanced_data
# return sampled_data, positive_data, balanced_data
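
# Illustrative: for a paraphrase run, sampling() returns a label-balanced frame
# (roughly half 1s and half 0s when both classes have enough rows):
#   balanced = sampling(read_data("mrpc"), "para", 3500)
#   balanced["label"].value_counts()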
if __name__ == "__main__":
# # For Testing
if sys.gettrace() is not None:
config = {
"dataset_name": "mrpc",
"task": "syn",
"target_lang": "en",
"output_dir": "./data/perturbed_dataset/",
"save": True
}
else:
args = get_args()
config = {
"dataset_name": args.dataset_name,
"task": args.task,
"target_lang": args.target_lang,
"output_dir": args.output_dir,
"save": args.save,
"sample_size": args.sample_size
}
    perturb_sentences(**config)
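
# Example invocation (illustrative; the script name and exact flag spellings
# are defined by your checkout and by get_args() in perturbation_args.py):
#   python sentence_perturbation.py --dataset_name mrpc --task syn \
#       --output_dir ./data/perturbed_dataset/ --sample_size 3500 --save True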