import pandas as pd
import numpy as np
from pathlib import Path
import json
import os


def get_threat_stats(df, lang='pt'):
    """Calculate threat statistics for a given language."""
    lang_df = df[df['lang'] == lang]
    total = int(len(lang_df))
    threat_count = int(lang_df['threat'].sum())
    return {
        'total': total,
        'threat_count': threat_count,
        'threat_ratio': float(threat_count / total if total > 0 else 0)
    }


def fix_pt_threat_distribution(input_dir='dataset/split', output_dir='dataset/balanced'):
    """Fix Portuguese threat class overrepresentation while maintaining dataset balance."""
    print("\n=== Fixing Portuguese Threat Distribution ===\n")

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    print("Loading datasets...")
    train_df = pd.read_csv(os.path.join(input_dir, 'train.csv'))
    val_df = pd.read_csv(os.path.join(input_dir, 'val.csv'))
    test_df = pd.read_csv(os.path.join(input_dir, 'test.csv'))

    print("\nInitial Portuguese Threat Distribution:")
    print("-" * 50)
    for name, df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]:
        stats = get_threat_stats(df)
        print(f"{name}: {stats['threat_count']}/{stats['total']} ({stats['threat_ratio']:.2%})")

    # Record the pre-fix Portuguese distributions now, so the report written
    # later reflects the data as loaded rather than the data after samples
    # have been dropped.
    original_distribution = {
        'train': get_threat_stats(train_df),
        'val': get_threat_stats(val_df),
        'test': get_threat_stats(test_df)
    }

    # The training-set Portuguese threat ratio is the target for the test set.
    target_ratio = float(get_threat_stats(train_df)['threat_ratio'])
    print(f"\nTarget threat ratio (from train): {target_ratio:.2%}")

    pt_test = test_df[test_df['lang'] == 'pt']
    current_ratio = float(get_threat_stats(test_df)['threat_ratio'])

    if current_ratio > target_ratio:
        # Drop enough Portuguese threat samples to bring the test ratio down
        # to approximately the training ratio.
        current_threats = int(pt_test['threat'].sum())
        target_threats = int(len(pt_test) * target_ratio)
        samples_to_remove = int(current_threats - target_threats)

print(f"\nRemoving {samples_to_remove} Portuguese threat samples from test set...") |
|
|
|
|
|
pt_threat_samples = test_df[ |
|
(test_df['lang'] == 'pt') & |
|
(test_df['threat'] > 0) |
|
] |
|
|
|
|
|
np.random.seed(42) |
|
remove_idx = np.random.choice( |
|
pt_threat_samples.index, |
|
size=samples_to_remove, |
|
replace=False |
|
).tolist() |
|
|
|
|
|
test_df = test_df.drop(remove_idx) |
|
|
|
|
|
new_ratio = float(get_threat_stats(test_df)['threat_ratio']) |
|
print(f"New Portuguese threat ratio: {new_ratio:.2%}") |
|
|
|
|
|
        # Persist a summary of the adjustment alongside the balanced datasets.
        stats = {
            'original_distribution': original_distribution,
            'samples_removed': samples_to_remove,
            'target_ratio': target_ratio,
            'achieved_ratio': new_ratio
        }

        with open(os.path.join(output_dir, 'pt_threat_fix_stats.json'), 'w') as f:
            json.dump(stats, f, indent=2)

print("\nSaving balanced datasets...") |
|
train_df.to_csv(os.path.join(output_dir, 'train_balanced.csv'), index=False) |
|
val_df.to_csv(os.path.join(output_dir, 'val_balanced.csv'), index=False) |
|
test_df.to_csv(os.path.join(output_dir, 'test_balanced.csv'), index=False) |
|
|
|
print("\nFinal Portuguese Threat Distribution:") |
|
print("-" * 50) |
|
for name, df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]: |
|
stats = get_threat_stats(df) |
|
print(f"{name}: {stats['threat_count']}/{stats['total']} ({stats['threat_ratio']:.2%})") |
|
else: |
|
print("\nNo fix needed - test set threat ratio is not higher than train") |
|
|
|
return train_df, val_df, test_df |
|
|
|
def validate_distributions(train_df, val_df, test_df):
    """Validate the threat distributions across all languages."""
    print("\nValidating Threat Distributions Across Languages:")
    print("-" * 50)

    for lang in sorted(train_df['lang'].unique()):
        print(f"\n{lang.upper()}:")
        for name, df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]:
            stats = get_threat_stats(df, lang)
            print(f"  {name}: {stats['threat_count']}/{stats['total']} ({stats['threat_ratio']:.2%})")


if __name__ == "__main__":
    train_df, val_df, test_df = fix_pt_threat_distribution()
    validate_distributions(train_df, val_df, test_df)