# Deeptanshuu's picture
# Upload folder using huggingface_hub
# d187b57 verified
import pandas as pd
import numpy as np
from pathlib import Path
import json
import os
def get_threat_stats(df, lang='pt'):
    """Summarise the threat labels for the rows of a single language.

    Args:
        df: DataFrame with a ``lang`` column and a ``threat`` column.
        lang: Language code to filter on (defaults to ``'pt'``).

    Returns:
        A dict of native Python values:
        ``total`` (number of rows for the language), ``threat_count`` (sum of
        the ``threat`` column over those rows) and ``threat_ratio``
        (``threat_count / total``, or ``0.0`` when the language has no rows).
    """
    subset = df.loc[df['lang'] == lang]

    # Cast to built-in types so the result is JSON-serialisable.
    n_rows = int(subset.shape[0])
    n_threats = int(subset['threat'].sum())

    if n_rows > 0:
        ratio = float(n_threats / n_rows)
    else:
        ratio = float(0)

    return {
        'total': n_rows,
        'threat_count': n_threats,
        'threat_ratio': ratio,
    }
def fix_pt_threat_distribution(input_dir='dataset/split', output_dir='dataset/balanced'):
    """Downsample Portuguese threat rows in the test split to the train ratio.

    Reads ``train.csv``, ``val.csv`` and ``test.csv`` from *input_dir*. When
    the Portuguese (``lang == 'pt'``) threat ratio of the test split is higher
    than that of the train split, randomly chosen Portuguese threat rows
    (``threat > 0``) are dropped from the test split. The three splits are then
    written to *output_dir* as ``*_balanced.csv`` together with
    ``pt_threat_fix_stats.json``. When no fix is needed, nothing is written
    (only *output_dir* is created).

    Args:
        input_dir: Directory containing the input CSV splits.
        output_dir: Directory receiving the balanced CSVs and the JSON report;
            created if it does not exist.

    Returns:
        Tuple ``(train_df, val_df, test_df)``. Only ``test_df`` can differ from
        what was loaded.
    """
    print("\n=== Fixing Portuguese Threat Distribution ===\n")

    # Create output directory
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Load datasets
    print("Loading datasets...")
    train_df = pd.read_csv(os.path.join(input_dir, 'train.csv'))
    val_df = pd.read_csv(os.path.join(input_dir, 'val.csv'))
    test_df = pd.read_csv(os.path.join(input_dir, 'test.csv'))

    # Snapshot the statistics BEFORE touching test_df, so that the report's
    # 'original_distribution' really describes the input data.
    original_stats = {
        'train': get_threat_stats(train_df),
        'val': get_threat_stats(val_df),
        'test': get_threat_stats(test_df),
    }

    print("\nInitial Portuguese Threat Distribution:")
    print("-" * 50)
    for split_name, split_stats in original_stats.items():
        print(
            f"{split_name.capitalize()}: "
            f"{split_stats['threat_count']}/{split_stats['total']} "
            f"({split_stats['threat_ratio']:.2%})"
        )

    # Calculate target ratio based on train set
    target_ratio = original_stats['train']['threat_ratio']
    print(f"\nTarget threat ratio (from train): {target_ratio:.2%}")

    current_ratio = original_stats['test']['threat_ratio']
    if current_ratio <= target_ratio:
        print("\nNo fix needed - test set threat ratio is not higher than train")
        return train_df, val_df, test_df

    # Rows that are candidates for removal.
    pt_threat_idx = test_df.index[
        (test_df['lang'] == 'pt') & (test_df['threat'] > 0)
    ]
    total_pt = original_stats['test']['total']
    current_threats = original_stats['test']['threat_count']

    # Removing k threat rows shrinks both numerator and denominator, so the
    # smallest k with (c - k) / (n - k) <= r is k = ceil((c - r*n) / (1 - r)).
    # Simply removing c - r*n rows would leave the ratio above the target.
    # NOTE(review): this assumes 'threat' is a 0/1 label, so ratios stay <= 1.
    if target_ratio < 1.0:
        exact = (current_threats - target_ratio * total_pt) / (1.0 - target_ratio)
        # The small epsilon keeps float noise from bumping an exact integer up.
        samples_to_remove = int(np.ceil(exact - 1e-9))
    else:
        samples_to_remove = 0
    # Never ask for more rows than exist (np.random.choice would raise).
    samples_to_remove = max(0, min(samples_to_remove, len(pt_threat_idx)))

    print(f"\nRemoving {samples_to_remove} Portuguese threat samples from test set...")

    # A private, seeded generator gives a reproducible selection without
    # reseeding the process-wide NumPy RNG.
    rng = np.random.RandomState(42)
    remove_idx = rng.choice(
        pt_threat_idx,
        size=samples_to_remove,
        replace=False
    ).tolist()  # Convert to native Python list

    # Remove selected samples
    test_df = test_df.drop(remove_idx)

    # Verify new distribution
    final_stats = {
        'train': get_threat_stats(train_df),
        'val': get_threat_stats(val_df),
        'test': get_threat_stats(test_df),
    }
    new_ratio = final_stats['test']['threat_ratio']
    print(f"New Portuguese threat ratio: {new_ratio:.2%}")

    # Save statistics
    report = {
        'original_distribution': original_stats,
        'final_distribution': final_stats,
        'samples_removed': samples_to_remove,
        'target_ratio': target_ratio,
        'achieved_ratio': new_ratio
    }
    with open(os.path.join(output_dir, 'pt_threat_fix_stats.json'), 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)

    # Save balanced datasets
    print("\nSaving balanced datasets...")
    train_df.to_csv(os.path.join(output_dir, 'train_balanced.csv'), index=False)
    val_df.to_csv(os.path.join(output_dir, 'val_balanced.csv'), index=False)
    test_df.to_csv(os.path.join(output_dir, 'test_balanced.csv'), index=False)

    print("\nFinal Portuguese Threat Distribution:")
    print("-" * 50)
    for split_name, split_stats in final_stats.items():
        print(
            f"{split_name.capitalize()}: "
            f"{split_stats['threat_count']}/{split_stats['total']} "
            f"({split_stats['threat_ratio']:.2%})"
        )

    return train_df, val_df, test_df
def validate_distributions(train_df, val_df, test_df):
    """Print the threat count and ratio of every language in every split.

    Languages are collected from all three splits (not just train), so a
    language that occurs only in the validation or test split is still
    reported. Rows with a missing ``lang`` value are ignored when building
    the language list.

    Args:
        train_df: Training split; needs ``lang`` and ``threat`` columns.
        val_df: Validation split with the same columns.
        test_df: Test split with the same columns.

    Returns:
        None. The report is written to stdout.
    """
    print("\nValidating Threat Distributions Across Languages:")
    print("-" * 50)

    splits = [('Train', train_df), ('Val', val_df), ('Test', test_df)]

    # Union over every split; NaN is dropped because it can be neither sorted
    # alongside strings nor upper-cased.
    languages = sorted(
        pd.concat([split['lang'] for _, split in splits]).dropna().unique()
    )

    for lang in languages:
        print(f"\n{lang.upper()}:")
        for name, df in splits:
            stats = get_threat_stats(df, lang)
            print(f" {name}: {stats['threat_count']}/{stats['total']} ({stats['threat_ratio']:.2%})")
if __name__ == "__main__":
    # Step 1: rebalance the Portuguese threat rows using the default
    # input/output directories.
    balanced_splits = fix_pt_threat_distribution()
    # Step 2: report the per-language threat distribution of the resulting
    # (train, val, test) splits.
    validate_distributions(*balanced_splits)