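"""Remove data leakage between train/val/test splits.

Each 'comment_text' value is normalized (lower-cased, whitespace-collapsed) and
hashed; rows whose hashes already appear in an earlier split are dropped from the
later split, and the cleaned CSVs are written to dataset/clean/.
"""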
import pandas as pd
import hashlib
import os
from collections import defaultdict
from pathlib import Path
def text_hash(text):
"""Create a hash of the text after basic normalization"""
# Convert to string and normalize
text = str(text).strip().lower()
# Remove extra whitespace
text = ' '.join(text.split())
# Create hash
return hashlib.sha256(text.encode()).hexdigest()
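
# Example: because of the normalization above, trivially different variants of the
# same comment collapse to a single hash, e.g.
#   text_hash("  Hello   World ") == text_hash("hello world")  # True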
def remove_leaked_samples(train_path, val_path, test_path, output_dir='dataset/clean'):
"""Remove overlapping samples between dataset splits"""
print("\n=== Removing Data Leakage ===\n")
# Create hash registry
hash_registry = defaultdict(set)
    original_sizes = {}
# Create output directory
Path(output_dir).mkdir(parents=True, exist_ok=True)
# Load datasets
print("Loading datasets...")
splits = {
'train': pd.read_csv(train_path),
'val': pd.read_csv(val_path),
'test': pd.read_csv(test_path)
}
# Store original sizes
for split_name, df in splits.items():
original_sizes[split_name] = len(df)
print(f"Original {split_name} size: {len(df):,} samples")
# Process each split
print("\nChecking for overlaps...")
removed_counts = defaultdict(int)
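    # Because splits are processed in insertion order (train, val, test) and the
    # registry only holds hashes from splits already seen, overlapping rows are
    # dropped from the later split: train is left untouched, val loses rows shared
    # with train, and test loses rows shared with train or val.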
for split_name, df in splits.items():
print(f"\nProcessing {split_name} split...")
# Calculate hashes for current split
current_hashes = set(df['comment_text'].apply(text_hash))
hash_registry[split_name] = current_hashes
# Check overlaps with other splits
for other_split in splits:
if other_split != split_name:
if hash_registry[other_split]: # Only check if other split is processed
overlaps = current_hashes & hash_registry[other_split]
if overlaps:
print(f" Found {len(overlaps):,} overlaps with {other_split}")
# Remove overlapping samples
df = df[~df['comment_text'].apply(text_hash).isin(overlaps)]
removed_counts[f"{split_name}_from_{other_split}"] = len(overlaps)
# Update splits dictionary with cleaned dataframe
splits[split_name] = df
# Save cleaned splits
print("\nSaving cleaned datasets...")
for split_name, df in splits.items():
output_path = os.path.join(output_dir, f"{split_name}_clean.csv")
df.to_csv(output_path, index=False)
reduction = ((original_sizes[split_name] - len(df)) / original_sizes[split_name]) * 100
print(f"Cleaned {split_name}: {len(df):,} samples (-{reduction:.2f}%)")
# Print detailed overlap statistics
print("\nDetailed Overlap Statistics:")
print("-" * 50)
for overlap_type, count in removed_counts.items():
split_name, other_split = overlap_type.split('_from_')
print(f"{split_name} → {other_split}: {count:,} overlapping samples removed")
return splits
def validate_cleaning(splits):
"""Validate that no overlaps remain between splits"""
print("\nValidating Cleaning...")
print("-" * 50)
all_clean = True
for split1 in splits:
for split2 in splits:
if split1 < split2: # Check each pair only once
hashes1 = set(splits[split1]['comment_text'].apply(text_hash))
hashes2 = set(splits[split2]['comment_text'].apply(text_hash))
overlaps = hashes1 & hashes2
if overlaps:
print(f"⚠️ Warning: Found {len(overlaps)} overlaps between {split1} and {split2}")
all_clean = False
else:
print(f"✅ No overlaps between {split1} and {split2}")
if all_clean:
print("\n✅ All splits are now clean with no overlaps!")
else:
print("\n⚠️ Some overlaps still remain. Consider additional cleaning.")
if __name__ == "__main__":
# Define paths
train_path = "dataset/split/train.csv"
val_path = "dataset/split/val.csv"
test_path = "dataset/split/test.csv"
# Remove leaked samples
cleaned_splits = remove_leaked_samples(train_path, val_path, test_path)
# Validate cleaning
    validate_cleaning(cleaned_splits)