import hashlib
import json
import logging
import time
from collections import defaultdict
from itertools import combinations
from pathlib import Path
from typing import Dict, Tuple, Set

import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

TOXICITY_COLUMNS = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
RARE_CLASSES = ['threat', 'identity_hate']
MIN_SAMPLES_PER_CLASS = 1000
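# Classes in RARE_CLASSES are oversampled per language toward MIN_SAMPLES_PER_CLASS
# positive examples; see oversample_rare_classes below.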


def create_multilabel_stratification_labels(row: pd.Series) -> str:
    """
    Create composite labels that preserve multi-label patterns and language distribution.
    Uses iterative label combination to capture co-occurrence patterns.
    """
    # Start from the language code so every split keeps the language mix.
    label = str(row['lang'])

    # Append one binary flag per toxicity class.
    for col in TOXICITY_COLUMNS:
        label += '_' + str(int(row[col]))

    # Append one flag per pair of rare classes that co-occur on this row.
    for c1, c2 in combinations(RARE_CLASSES, 2):
        co_occur = int(row[c1] == 1 and row[c2] == 1)
        label += '_' + str(co_occur)

    return label
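# Example: a row with lang == 'en' that is labelled only toxic and insult yields
# the composite label "en_1_0_0_0_1_0_0" (language code, six per-class flags in
# TOXICITY_COLUMNS order, then one threat/identity_hate co-occurrence flag).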


def oversample_rare_classes(df: pd.DataFrame) -> pd.DataFrame:
    """
    Perform intelligent oversampling of rare classes while maintaining language distribution.
    """
    oversampled_dfs = []
    original_df = df.copy()

    # Oversample within each language so the language distribution is preserved.
    for lang in df['lang'].unique():
        lang_df = df[df['lang'] == lang]

        for rare_class in RARE_CLASSES:
            class_samples = lang_df[lang_df[rare_class] == 1]
            target_samples = MIN_SAMPLES_PER_CLASS

            # Nothing to resample from if the class never occurs for this language.
            if class_samples.empty:
                continue

            if len(class_samples) < target_samples:
                n_samples = target_samples - len(class_samples)

                # Duplicate rare-class rows with replacement, then soften the
                # non-rare label columns with small Gaussian noise so the copies
                # are not exact label duplicates.
                noise = np.random.normal(0, 0.1, (n_samples, len(TOXICITY_COLUMNS)))
                oversampled = class_samples.sample(n_samples, replace=True)

                for col in TOXICITY_COLUMNS:
                    if col in RARE_CLASSES:
                        continue
                    oversampled[col] = np.clip(
                        oversampled[col].values + noise[:, TOXICITY_COLUMNS.index(col)],
                        0, 1
                    )

                oversampled_dfs.append(oversampled)

    if oversampled_dfs:
        return pd.concat([original_df] + oversampled_dfs, axis=0).reset_index(drop=True)
    return original_df
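# Note: the duplicated rows keep hard 0/1 values for the rare classes but end up
# with soft values in [0, 1] for the remaining label columns; downstream training
# code is assumed to accept such soft labels.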


def verify_distributions(
    original_df: pd.DataFrame,
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: pd.DataFrame = None
) -> Dict:
    """
    Enhanced verification of distributions across splits with detailed metrics.
    """
    splits = {
        'original': original_df,
        'train': train_df,
        'val': val_df
    }
    if test_df is not None:
        splits['test'] = test_df

    stats = defaultdict(dict)

    for split_name, df in splits.items():
        # Language distribution of the split.
        stats[split_name]['language_dist'] = df['lang'].value_counts(normalize=True).to_dict()

        # Per-language positive rate and count for every toxicity class.
        lang_class_dist = {}
        for lang in df['lang'].unique():
            lang_df = df[df['lang'] == lang]
            lang_class_dist[lang] = {
                col: {
                    'positive_ratio': float(lang_df[col].mean()),
                    'count': int(lang_df[col].sum()),
                    'total': len(lang_df)
                } for col in TOXICITY_COLUMNS
            }
        stats[split_name]['lang_class_dist'] = lang_class_dist

        # Pairwise co-occurrence of toxicity classes.
        cooccurrence = {}
        for c1, c2 in combinations(TOXICITY_COLUMNS, 2):
            cooccur_count = ((df[c1] == 1) & (df[c2] == 1)).sum()
            cooccurrence[f"{c1}_{c2}"] = {
                'count': int(cooccur_count),
                'ratio': float(cooccur_count) / len(df)
            }
        stats[split_name]['cooccurrence_patterns'] = cooccurrence

        # Per-language, per-class drift relative to the original dataset.
        if split_name != 'original':
            deltas = {}
            for lang in df['lang'].unique():
                for col in TOXICITY_COLUMNS:
                    orig_ratio = splits['original'][splits['original']['lang'] == lang][col].mean()
                    split_ratio = df[df['lang'] == lang][col].mean()
                    deltas[f"{lang}_{col}"] = float(abs(orig_ratio - split_ratio))
            stats[split_name]['distribution_deltas'] = deltas

    return stats
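# The 'distribution_deltas' entries are absolute differences in per-language
# positive rate between a split and the original (deduplicated) data; values
# close to zero mean the stratification preserved that class for that language.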


def check_contamination(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: pd.DataFrame = None
) -> Dict:
    """
    Enhanced contamination check including text similarity detection.
    """
    text_column = 'comment_text' if 'comment_text' in train_df.columns else 'text'
    if text_column not in train_df.columns:
        logging.warning("No text column found for contamination check. Skipping text-based contamination detection.")
        return {'exact_matches': {'train_val': 0.0}}

    def get_text_set(df: pd.DataFrame) -> Set[str]:
        # Lower-cased, stripped texts are compared for exact overlap across splits.
        return set(df[text_column].str.lower().str.strip().values)

    contamination = {
        'exact_matches': {
            'train_val': len(get_text_set(train_df) & get_text_set(val_df)) / len(train_df)
        }
    }

    if test_df is not None:
        contamination['exact_matches'].update({
            'train_test': len(get_text_set(train_df) & get_text_set(test_df)) / len(train_df),
            'val_test': len(get_text_set(val_df) & get_text_set(test_df)) / len(val_df)
        })

    return contamination
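# Each ratio is the number of distinct normalized texts shared by the two splits,
# divided by the row count of the first split named in the key (e.g. 'train_val'
# divides by len(train_df)); after deduplication these should all be 0.0.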


def split_dataset(
    df: pd.DataFrame,
    seed: int,
    split_mode: str
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Perform stratified splitting of the dataset.

    Rare-class oversampling is applied to the training split only, after the
    split, so duplicated rows cannot leak into validation or test.
    """
    logging.info("Creating stratification labels...")
    stratify_labels = df.apply(create_multilabel_stratification_labels, axis=1)

    if split_mode == '3':
        # 5-fold split: hold out one fold, giving roughly 80% train / 20% temp.
        splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
        train_idx, temp_idx = next(splitter.split(df, stratify_labels))

        temp_df = df.iloc[temp_idx]
        temp_labels = stratify_labels.iloc[temp_idx]

        # Split the held-out 20% in half: roughly 10% val / 10% test.
        splitter = StratifiedKFold(n_splits=2, shuffle=True, random_state=seed)
        val_idx, test_idx = next(splitter.split(temp_df, temp_labels))

        train_df = df.iloc[train_idx]
        val_df = temp_df.iloc[val_idx]
        test_df = temp_df.iloc[test_idx]
    else:
        # 10-fold split: hold out one fold, giving roughly 90% train / 10% val.
        splitter = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
        train_idx, val_idx = next(splitter.split(df, stratify_labels))

        train_df = df.iloc[train_idx]
        val_df = df.iloc[val_idx]
        test_df = None

    logging.info("Oversampling rare classes...")
    train_df = oversample_rare_classes(train_df)

    return train_df, val_df, test_df
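# Note: StratifiedKFold raises a ValueError if any composite label occurs fewer
# than n_splits times, so extremely rare language/label combinations may need to
# be merged into coarser labels before splitting.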


def save_splits(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: pd.DataFrame,
    output_dir: str,
    stats: Dict
) -> None:
    """
    Save splits and statistics to files.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    logging.info("Saving splits...")
    train_df.to_csv(output_path / 'train.csv', index=False)
    val_df.to_csv(output_path / 'val.csv', index=False)
    if test_df is not None:
        test_df.to_csv(output_path / 'test.csv', index=False)

    with open(output_path / 'stats.json', 'w', encoding='utf-8') as f:
        json.dump(stats, f, indent=2, ensure_ascii=False)


def compute_text_hash(text: str) -> str:
    """
    Compute SHA-256 hash of normalized text.
    """
    normalized = ' '.join(str(text).lower().split())
    return hashlib.sha256(normalized.encode('utf-8')).hexdigest()
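# Example: "  Hello\tWORLD  " and "hello world" normalize to the same string and
# therefore share a hash, so deduplication treats them as the same comment.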


def deduplicate_dataset(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict]:
    """
    Remove duplicates using cryptographic hashing while preserving metadata.
    """
    logging.info("Starting cryptographic deduplication...")

    text_column = 'comment_text' if 'comment_text' in df.columns else 'text'
    if text_column not in df.columns:
        raise ValueError(f"No text column found. Available columns: {list(df.columns)}")

    logging.info("Computing cryptographic hashes...")
    tqdm.pandas(desc="Hashing texts")
    df['text_hash'] = df[text_column].progress_apply(compute_text_hash)

    # Record every group of rows that share a hash before dropping the duplicates.
    total_samples = len(df)
    duplicated_mask = df.duplicated('text_hash', keep=False)
    duplicate_groups = {
        hash_val: indices.tolist()
        for hash_val, indices in df[duplicated_mask].groupby('text_hash').groups.items()
    }

    # Keep the first occurrence of each hash and drop the helper column.
    dedup_df = df.drop_duplicates('text_hash', keep='first').copy()
    dedup_df = dedup_df.drop('text_hash', axis=1)

    dedup_stats = {
        'total_samples': total_samples,
        'unique_samples': len(dedup_df),
        'duplicates_removed': total_samples - len(dedup_df),
        'duplicate_rate': (total_samples - len(dedup_df)) / total_samples,
        'duplicate_groups': {
            str(k): {
                'count': len(v),
                'indices': v
            }
            for k, v in duplicate_groups.items()
        }
    }

    logging.info(f"Removed {dedup_stats['duplicates_removed']:,} duplicates "
                 f"({dedup_stats['duplicate_rate']:.2%} of dataset)")

    return dedup_df, dedup_stats
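# Only the first occurrence of each hash survives; the dropped row indices are
# kept in dedup_stats['duplicate_groups'] so individual removals can be audited.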


def main():
    input_csv = 'dataset/processed/MULTILINGUAL_TOXIC_DATASET_AUGMENTED.csv'
    output_dir = 'dataset/split'
    seed = 42
    split_mode = '3'

    start_time = time.time()

    logging.info(f"Loading dataset from {input_csv}...")
    df = pd.read_csv(input_csv)

    logging.info(f"Available columns: {', '.join(df.columns)}")

    # Fail fast if the language or label columns are missing.
    required_columns = ['lang'] + TOXICITY_COLUMNS
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")

    # Deduplicate before splitting so identical texts cannot land in different splits.
    df, dedup_stats = deduplicate_dataset(df)

    logging.info("Performing stratified split...")
    train_df, val_df, test_df = split_dataset(df, seed, split_mode)

    logging.info("Verifying distributions...")
    stats = verify_distributions(df, train_df, val_df, test_df)
    stats['deduplication'] = dedup_stats

    logging.info("Checking for contamination...")
    contamination = check_contamination(train_df, val_df, test_df)
    stats['contamination'] = contamination

    logging.info(f"Saving splits to {output_dir}...")
    save_splits(train_df, val_df, test_df, output_dir, stats)

    elapsed_time = time.time() - start_time
    logging.info(f"Done! Elapsed time: {elapsed_time:.2f} seconds")

    print("\nDeduplication Summary:")
    print("-" * 50)
    print(f"Original samples: {dedup_stats['total_samples']:,}")
    print(f"Unique samples: {dedup_stats['unique_samples']:,}")
    print(f"Duplicates removed: {dedup_stats['duplicates_removed']:,} ({dedup_stats['duplicate_rate']:.2%})")

    print("\nSplit Summary:")
    print("-" * 50)
    print(f"Total samples: {len(df):,}")
    print(f"Train samples (after oversampling): {len(train_df):,} ({len(train_df)/len(df)*100:.1f}%)")
    print(f"Validation samples: {len(val_df):,} ({len(val_df)/len(df)*100:.1f}%)")
    if test_df is not None:
        print(f"Test samples: {len(test_df):,} ({len(test_df)/len(df)*100:.1f}%)")
    print("\nDetailed statistics saved to stats.json")


if __name__ == "__main__":
    main()