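"""Balance the English portion of the multilingual toxic-comment dataset.

Computes how many English samples are needed to match the largest language
in the corpus, generates them with ToxicAugmenter while preserving the
original per-label ratios, and writes the merged dataset plus distribution
statistics to disk.
"""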
import os

# Configure CPU and thread environment variables FIRST, before importing
# torch or anything that loads TensorFlow/MKL/OpenMP, so they take effect.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'
os.environ['TF_CPU_ENABLE_AVX2'] = '1'
os.environ['TF_CPU_ENABLE_AVX512F'] = '1'
os.environ['TF_CPU_ENABLE_AVX512_VNNI'] = '1'
os.environ['TF_CPU_ENABLE_FMA'] = '1'
os.environ['MKL_NUM_THREADS'] = '80'
os.environ['OMP_NUM_THREADS'] = '80'

import torch

# Set PyTorch thread configurations once
torch.set_num_threads(80)
torch.set_num_interop_threads(10)

# Now import everything else
import pandas as pd
import numpy as np
from pathlib import Path
import logging
from datetime import datetime
import sys
from toxic_augment import ToxicAugmenter
import json

# Configure logging
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)

timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
log_file = log_dir / f"balance_english_{timestamp}.log"

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(log_file)
    ]
)

logger = logging.getLogger(__name__)

def analyze_label_distribution(df, lang='en'):
    """Analyze label distribution for a specific language"""
    labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    lang_df = df[df['lang'] == lang]
    total = len(lang_df)
    
    if total == 0:
        logger.warning(f"No samples found for language {lang.upper()}.")
        return {}
    
    logger.info(f"\nLabel Distribution for {lang.upper()}:")
    logger.info("-" * 50)
    dist = {}
    for label in labels:
        count = lang_df[label].sum()
        percentage = (count / total) * 100
        dist[label] = {'count': int(count), 'percentage': float(percentage)}
        logger.info(f"{label}: {count:,} ({percentage:.2f}%)")
    return dist

def analyze_language_distribution(df):
    """Analyze current language distribution"""
    lang_dist = df['lang'].value_counts()
    logger.info("\nCurrent Language Distribution:")
    logger.info("-" * 50)
    for lang, count in lang_dist.items():
        logger.info(f"{lang}: {count:,} comments ({count/len(df)*100:.2f}%)")
    return lang_dist

def calculate_required_samples(df):
    """Calculate how many English samples we need to generate"""
    lang_counts = df['lang'].value_counts()
    target_count = lang_counts.max()  # Use the largest language count as target
    en_count = lang_counts.get('en', 0)
    required_samples = target_count - en_count
    
    logger.info(f"\nTarget count per language: {target_count:,}")
    logger.info(f"Current English count: {en_count:,}")
    logger.info(f"Required additional English samples: {required_samples:,}")
    
    return required_samples

def generate_balanced_samples(df, required_samples):
    """Generate samples maintaining original class distribution ratios"""
    logger.info("\nGenerating balanced samples...")
    
    # Get English samples
    en_df = df[df['lang'] == 'en']
    labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    
    # Calculate target counts for each label
    target_counts = {}
    for label in labels:
        count = en_df[label].sum()
        ratio = count / len(en_df)
        target_count = int(ratio * required_samples)
        target_counts[label] = target_count
        logger.info(f"Target count for {label}: {target_count:,}")
    
    augmented_samples = []
    augmenter = ToxicAugmenter()
    total_generated = 0
    
    # Generate samples for each label
    for label, target_count in target_counts.items():
        if target_count == 0:
            continue
            
        logger.info(f"\nGenerating {target_count:,} samples for {label}")
        
        # Get seed texts with this label
        seed_texts = en_df[en_df[label] == 1]['comment_text'].tolist()
        
        if not seed_texts:
            logger.warning(f"No seed texts found for {label}, skipping...")
            continue
        
        # Generate samples with 5-minute timeout
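        # ToxicAugmenter.augment_dataset (from toxic_augment) is expected to
        # return a DataFrame of generated comments carrying the label columns;
        # None or an empty frame is treated as a failed generation below.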
        new_samples = augmenter.augment_dataset(
            target_samples=target_count,
            label=label,  # Using single label instead of label_combo
            seed_texts=seed_texts,
            timeout_minutes=5
        )
        
        if new_samples is not None and not new_samples.empty:
            augmented_samples.append(new_samples)
            total_generated += len(new_samples)
            
            # Log progress
            logger.info(f"✓ Generated {len(new_samples):,} samples")
            logger.info(f"Progress: {total_generated:,}/{required_samples:,}")
            
            # Check if we have reached our global required samples
            if total_generated >= required_samples:
                logger.info("Reached required sample count, stopping generation")
                break
    
    # Combine all generated samples
    if augmented_samples:
        augmented_df = pd.concat(augmented_samples, ignore_index=True)
        augmented_df['lang'] = 'en'
        
        # Ensure we don't exceed the required sample count
        if len(augmented_df) > required_samples:
            logger.info(f"Trimming excess samples from {len(augmented_df):,} to {required_samples:,}")
            augmented_df = augmented_df.head(required_samples)
        
        # Log final class distribution
        logger.info("\nFinal class distribution in generated samples:")
        for label in labels:
            count = augmented_df[label].sum()
            percentage = (count / len(augmented_df)) * 100
            logger.info(f"{label}: {count:,} ({percentage:.2f}%)")
        
        # Also log clean samples
        clean_count = len(augmented_df[augmented_df[labels].sum(axis=1) == 0])
        clean_percentage = (clean_count / len(augmented_df)) * 100
        logger.info(f"Clean samples: {clean_count:,} ({clean_percentage:.2f}%)")
        
        return augmented_df
    else:
        raise RuntimeError("Failed to generate any valid samples")

def balance_english_data():
    """Main function to balance English data with other languages"""
    try:
        # Load dataset
        input_file = 'dataset/processed/MULTILINGUAL_TOXIC_DATASET_360K_7LANG_FINAL.csv'
        logger.info(f"Loading dataset from {input_file}")
        df = pd.read_csv(input_file)
        
        # Analyze current distribution
        logger.info("\nAnalyzing current distribution...")
        initial_dist = analyze_language_distribution(df)
        initial_label_dist = analyze_label_distribution(df, 'en')
        
        # Calculate required samples
        required_samples = calculate_required_samples(df)
        
        if required_samples <= 0:
            logger.info("English data is already balanced. No augmentation needed.")
            return
        
        # Generate balanced samples
        augmented_df = generate_balanced_samples(df, required_samples)
        
        # Merge with original dataset
        logger.info("\nMerging datasets...")
        output_file = 'dataset/processed/MULTILINGUAL_TOXIC_DATASET_BALANCED.csv'
        
        # Combine datasets
        combined_df = pd.concat([df, augmented_df], ignore_index=True)
        
        # Save balanced dataset
        combined_df.to_csv(output_file, index=False)
        logger.info(f"\nSaved balanced dataset to {output_file}")
        
        # Final distribution check
        logger.info("\nFinal distribution after balancing:")
        final_dist = analyze_language_distribution(combined_df)
        final_label_dist = analyze_label_distribution(combined_df, 'en')
        
        # Save distribution statistics (cast pandas/NumPy counts to int so json can serialize them)
        stats = {
            'timestamp': timestamp,
            'initial_distribution': {
                'languages': {lang: int(count) for lang, count in initial_dist.items()},
                'english_labels': initial_label_dist
            },
            'final_distribution': {
                'languages': {lang: int(count) for lang, count in final_dist.items()},
                'english_labels': final_label_dist
            },
            'samples_generated': len(augmented_df),
            'total_samples': len(combined_df)
        }
        
        stats_file = f'logs/balance_stats_{timestamp}.json'
        with open(stats_file, 'w') as f:
            json.dump(stats, f, indent=2)
        logger.info(f"\nSaved balancing statistics to {stats_file}")
        
    except Exception as e:
        logger.error(f"Error during balancing: {str(e)}")
        raise

def main():
    balance_english_data()

if __name__ == "__main__":
    logger.info("Starting English data balancing process...")
    main()