import hashlib
import os
from collections import defaultdict
from itertools import combinations
from pathlib import Path

import pandas as pd


def text_hash(text):
    """Return a SHA-256 digest of the text after basic normalization."""
    # Lowercase, trim, and collapse internal whitespace so trivial
    # formatting differences do not hide duplicates.
    text = str(text).strip().lower()
    text = ' '.join(text.split())
    return hashlib.sha256(text.encode()).hexdigest()
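
# Illustrative check (hypothetical strings, not from the dataset): inputs that
# differ only in case or whitespace map to the same digest, so the dedup below
# treats them as the same comment:
#   text_hash("Great   Movie!") == text_hash("great movie!")  # -> True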


def remove_leaked_samples(train_path, val_path, test_path, output_dir='dataset/clean'):
    """Remove overlapping samples between dataset splits."""
    print("\n=== Removing Data Leakage ===\n")

    # Registry of normalized-text hashes, keyed by split name.
    hash_registry = defaultdict(set)
    original_sizes = {}

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    print("Loading datasets...")
    splits = {
        'train': pd.read_csv(train_path),
        'val': pd.read_csv(val_path),
        'test': pd.read_csv(test_path),
    }
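
    # Each CSV is assumed to contain a 'comment_text' column; all of the
    # hashing below keys on that column.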

    for split_name, df in splits.items():
        original_sizes[split_name] = len(df)
        print(f"Original {split_name} size: {len(df):,} samples")

    print("\nChecking for overlaps...")
    removed_counts = defaultdict(int)

    # Splits are processed in insertion order (train, val, test), so a
    # duplicate is kept in the earliest split it appears in and removed
    # from every later one.
    for split_name, df in splits.items():
        print(f"\nProcessing {split_name} split...")

        # Hash each comment once and reuse the Series for both the
        # registry and the removal filter.
        hashes = df['comment_text'].apply(text_hash)
        hash_registry[split_name] = set(hashes)

        for other_split in splits:
            if other_split == split_name or not hash_registry[other_split]:
                continue
            overlaps = set(hashes) & hash_registry[other_split]
            if overlaps:
                print(f"  Found {len(overlaps):,} overlaps with {other_split}")
                keep = ~hashes.isin(overlaps)
                # Count rows actually dropped (may exceed the number of
                # distinct overlapping hashes if a split contains internal
                # duplicates).
                removed_counts[f"{split_name}_from_{other_split}"] = int((~keep).sum())
                df, hashes = df[keep], hashes[keep]

        splits[split_name] = df
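
    # Note: the registry keeps each split's pre-removal hashes. That is safe
    # because any hash removed here already exists in an earlier split's
    # registry, so later splits are still filtered against it.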

    print("\nSaving cleaned datasets...")
    for split_name, df in splits.items():
        output_path = os.path.join(output_dir, f"{split_name}_clean.csv")
        df.to_csv(output_path, index=False)
        reduction = (original_sizes[split_name] - len(df)) / original_sizes[split_name] * 100
        print(f"Cleaned {split_name}: {len(df):,} samples (-{reduction:.2f}%)")

    print("\nDetailed Overlap Statistics:")
    print("-" * 50)
    for overlap_type, count in removed_counts.items():
        split_name, other_split = overlap_type.split('_from_')
        print(f"{split_name} → {other_split}: {count:,} overlapping samples removed")

    return splits


def validate_cleaning(splits):
    """Validate that no overlaps remain between splits."""
    print("\nValidating Cleaning...")
    print("-" * 50)

    all_clean = True
    # Check every unordered pair of splits exactly once.
    for split1, split2 in combinations(splits, 2):
        hashes1 = set(splits[split1]['comment_text'].apply(text_hash))
        hashes2 = set(splits[split2]['comment_text'].apply(text_hash))
        overlaps = hashes1 & hashes2
        if overlaps:
            print(f"⚠️ Warning: Found {len(overlaps)} overlaps between {split1} and {split2}")
            all_clean = False
        else:
            print(f"✅ No overlaps between {split1} and {split2}")

    if all_clean:
        print("\n✅ All splits are now clean with no overlaps!")
    else:
        print("\n⚠️ Some overlaps still remain. Consider additional cleaning.")


if __name__ == "__main__":
    train_path = "dataset/split/train.csv"
    val_path = "dataset/split/val.csv"
    test_path = "dataset/split/test.csv"

    cleaned_splits = remove_leaked_samples(train_path, val_path, test_path)
    validate_cleaning(cleaned_splits)
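
    # Cleaned files are written to dataset/clean/ as train_clean.csv,
    # val_clean.csv, and test_clean.csv (the output_dir default above).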