File size: 8,932 Bytes
4f08d2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
from absl import logging

import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import sys
sys.path.insert(0, "/home/yash/ALIGN-SIM/src")
from utils import mkdir_p, full_path, read_data
from SentencePerturbation.word_replacer import WordReplacer, WordSwapping
import random
from perturbation_args import get_args



def perturb_sentences(dataset_name: str, task: str, target_lang:str ="en", output_dir: str = "./data/perturbed_dataset/", sample_size: int = 3500, save :str = False) -> None:
    """
    perturb_sentences _summary_

    Args:
        dataset_name (str): ["MRPC","PAWS","QQP"]
        task (str): ["Synonym","Antonym","Jumbling"]
        target_lang (str, optional): _description_. Defaults to "en".
        output_dir (str, optional): _description_. Defaults to "./data/perturbed_dataset/".
        sample_size (int, optional): _description_. Defaults to 3500.
        save (str, optional): _description_. Defaults to False.
    """
    
    print("--------------------------------------")
    
    output_csv = full_path(os.path.join(output_dir, target_lang, task, f"{dataset_name}_{task}_perturbed_{target_lang}.csv"))
    if os.path.exists(output_csv):
        print(f"File already exists at: {output_csv}")
        return 
    
    # TODO: make it compatible with other language datasets
    print("Loading dataset...")
    data = read_data(dataset_name) 
    if "Unnamed: 0" in data.columns:
        data.drop("Unnamed: 0", axis=1, inplace=True)
    
    if "idx" in data.columns:
        data.drop("idx", axis=1, inplace=True)
        
    print(f"Loaded {dataset_name} dataset")
    
    print("--------------------------------------")

    
    # Initialize WordReplacer
    replacer = WordReplacer()
    # set seed
    random.seed(42)
    
    # Create a new dataframe to store perturbed sentences
    # Sample sentences
    perturbed_data = pd.DataFrame(columns=["original_sentence"])
    # sample_data , pos_pairs, balance_dataset  = sampling(data, sample_size)
    
    
    if task in ["Syn","syn","Synonym"]:
        print("Creating Synonym perturbed data...")
        sample_data = sampling(data, task, sample_size)
        perturbed_data["original_sentence"] = sample_data.sentence1
        perturbed_data["perturb_n1"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 1, "synonyms"))
        perturbed_data["perturb_n2"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 2, "synonyms"))
        perturbed_data["perturb_n3"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 3, "synonyms"))
        
        assert perturbed_data.shape[1] == 4, "Perturbed data size mismatch"
    
    if task in ["paraphrase","Paraphrase","para"]:
        print("Creating Paraphrase perturbed data...")
        # shuffling the negative samples
        # we also want equal number of positive and negative samples
        perturbed_data = sampling(data, task, sample_size) # balance data
        perturbed_data["original_sentence"] = perturbed_data.sentence1
        perturbed_data["paraphrased_sentence"] = perturbed_data.sentence2
        assert perturbed_data.shape[1] == 3, "Perturbed data size mismatch" # original_sentence, paraphrased, label
        
    if task in ["Anto","anto","Antonym"]:
        print("Creating Antonym perturbed data...")
        pos_pairs = sampling(data, task, sample_size)
        # Apply antonym replacement
        perturbed_data["original_sentence"] = pos_pairs.sentence1
        perturbed_data["paraphrased_sentence"] = pos_pairs.sentence2
        perturbed_data["perturb_n1"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 1, "antonyms"))
        assert perturbed_data.shape[1] == 3, "Perturbed data size mismatch"
        
    # Apply jumbling
    if task in ["jumbling", "Jumbling","jumb"]:
        print("Creating Jumbling perturbed data...")
        pos_pairs = sampling(data, task, sample_size)
        perturbed_data["original_sentence"] = pos_pairs.sentence1
        perturbed_data["paraphrased_sentence"] = pos_pairs.sentence2
        perturbed_data["perturb_n1"]= perturbed_data["original_sentence"].apply(lambda x: WordSwapping.random_swap(x,1))
        perturbed_data["perturb_n2"]= perturbed_data["original_sentence"].apply(lambda x: WordSwapping.random_swap(x,2))
        perturbed_data["perturb_n3"]= perturbed_data["original_sentence"].apply(lambda x: WordSwapping.random_swap(x,3))
        
        assert perturbed_data.shape[1] == 5, "Perturbed data size mismatch"
    # Save to CSV
    if save:
        perturbed_data.to_csv(mkdir_p(output_csv), index=False)
        print("--------------------------------------")
        print(f"Saved at: {output_csv}")
        print("--------------------------------------")



def sampling(data: pd.DataFrame, task :str, sample_size: int, random_state: int = 42):
    """
    Combines two sampling strategies:
    
    1. sampled_data: Samples from the dataset by first taking all positive pairs and then,
       if needed, filling the remainder with negative pairs.
    2. balanced_data: Constructs a dataset with roughly equal positive and negative pairs,
       adjusting the numbers if one group is underrepresented.
    
    Returns:
        sampled_data (pd.DataFrame): Dataset sampled by filling negatives if positives are insufficient.
        positive_data (pd.DataFrame): All positive samples (label == 1).
        balanced_data (pd.DataFrame): Dataset balanced between positive and negative pairs.
    """
    # Split the data into positive and negative pairs
    positive_data = data[data["label"] == 1]
    negative_data = data[data["label"] == 0]
    
    if task in ["Anto","anto","Antonym","jumbling", "Jumbling","jumb"]:
        return positive_data
    
    # ----- Sampling positive pair, but also checking if we satisfy sample size -----
    if sample_size is None or sample_size > len(positive_data):
        # If no sample size is provided or it exceeds the available data,
        # return a copy of the entire dataset.
        sampled_data = positive_data.copy()
    else:
        # Otherwise, randomly sample the specified number of rows.
        sampled_data = positive_data.sample(n=sample_size, random_state=random_state)

        
    if task in ["Syn","syn","Synonym"]:
        return sampled_data

    # ----- Sampling for Paraphrased Criterion -----
    # Shuffle negative pairs first
    negative_data = negative_data.reset_index(drop=True)
    shuffled_sentence2 = negative_data["sentence2"].sample(frac=1, random_state=random_state).reset_index(drop=True)
    negative_data["sentence2"] = shuffled_sentence2

    # Determine ideal sample size per group (half of total sample size)
    if sample_size is None:
        pos_sample_size = len(positive_data)
        neg_sample_size = len(negative_data)
    else:
        # Determine ideal sample size per group (half of total sample size)
        half_size = sample_size // 2
        pos_available = len(positive_data)
        neg_available = len(negative_data)
        pos_sample_size = min(half_size, pos_available)
        neg_sample_size = min(half_size, neg_available)

        # If there is a remainder, add extra samples from the group with more available data.
        total_sampled = pos_sample_size + neg_sample_size
        remainder = sample_size - total_sampled
        if remainder > 0:
            if (pos_available - pos_sample_size) >= (neg_available - neg_sample_size):
                pos_sample_size += remainder
            else:
                neg_sample_size += remainder

    # Sample from each group
    sampled_positive = positive_data.sample(n=pos_sample_size, random_state=random_state)
    sampled_negative = negative_data.sample(n=neg_sample_size, random_state=random_state)
    # Add a 'label' column
    sampled_positive["label"] = 1
    sampled_negative["label"] = 0
    # Combine and shuffle the resulting dataset
    balanced_data = pd.concat([sampled_positive, sampled_negative]).sample(frac=1, random_state=random_state).reset_index(drop=True)
    
    if task in ["paraphrase","Paraphrase","para"]:
        return balanced_data
    # return sampled_data, positive_data, balanced_data



if __name__ == "__main__":

    # # For Testing
    if sys.gettrace() is not None:
        config = {
            "dataset_name": "mrpc",
            "task": "syn",
            "target_lang": "en",
            "output_dir": "./data/perturbed_dataset/",
            "save": True
        }
    else: 
        args = get_args()
        config = {
            "dataset_name": args.dataset_name,
            "task": args.task,
            "target_lang": args.target_lang,
            "output_dir": args.output_dir,
            "save": args.save,
            "sample_size": args.sample_size
        }
    perturb_sentences(**config)