diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..c3212fd0be32dba5ac5442412b529a09694bfe0c --- /dev/null +++ b/.dockerignore @@ -0,0 +1,48 @@ +# Git +.git +.gitignore + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual Environment +venv/ +ENV/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Project specific +dataset/ +weights/ +wandb/ +*.pt +*.pth +*.ckpt + +# Logs +*.log +logs/ \ No newline at end of file diff --git a/.env.template b/.env.template new file mode 100644 index 0000000000000000000000000000000000000000..ba13f716197f3039622b832f76925e3c5332e38d --- /dev/null +++ b/.env.template @@ -0,0 +1,9 @@ +# Weights & Biases API Key +WANDB_API_KEY= + +# Model Configuration +BATCH_SIZE=16 +GRAD_ACCUM_STEPS=4 +EPOCHS=5 +LEARNING_RATE=2e-5 +MODEL_NAME=xlm-roberta-large \ No newline at end of file diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..a022eb1ad8ecf5326f10cdad00a729e9a380e50f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,35 +1,27 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text +dataset/raw/MULTILINGUAL_TOXIC_DATASET_360K_7LANG.csv filter=lfs diff=lfs merge=lfs -text +dataset/raw/MULTILINGUAL_TOXIC_DATASET_360K_7LANG_binary.csv filter=lfs diff=lfs merge=lfs -text +dataset/processed/MULTILINGUAL_TOXIC_DATASET_360K_7LANG_FINAL.csv filter=lfs diff=lfs merge=lfs -text +dataset/split/train.csv filter=lfs diff=lfs merge=lfs -text +dataset/processed/MULTILINGUAL_TOXIC_DATASET_AUGMENTED.csv filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250208_161149/plots/calibration_0.png filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250208_161149/plots/calibration_1.png filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250208_161149/plots/calibration_2.png filter=lfs diff=lfs merge=lfs -text 
+evaluation_results/eval_20250208_161149/plots/calibration_3.png filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250208_161149/plots/calibration_4.png filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250208_161149/plots/calibration_5.png filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250208_161149/plots/calibration_6.png filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250208_161149/plots/class_calibration.png filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250208_161149/predictions.npz filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250401_143401/plots/roc_all_classes.png filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250401_143401/plots/roc_by_language.png filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250401_143401/plots/roc_identity_hate.png filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250401_143401/plots/roc_insult.png filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250401_143401/plots/roc_obscene.png filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250401_143401/plots/roc_severe_toxic.png filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250401_143401/plots/roc_threat.png filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250401_143401/plots/roc_toxic.png filter=lfs diff=lfs merge=lfs -text +evaluation_results/eval_20250401_143401/predictions.npz filter=lfs diff=lfs merge=lfs -text +images/class_distribution.png filter=lfs diff=lfs merge=lfs -text +images/language_distribution.png filter=lfs diff=lfs merge=lfs -text +images/toxicity_by_language.png filter=lfs diff=lfs merge=lfs -text +images/toxicity_correlation.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9a838cafc7189f5c6e8eff280530258bc600fb60 --- /dev/null +++ b/.gitignore @@ -0,0 +1,83 @@ +# Python cache files +__pycache__/ +*.py[cod] + +# Virtual environment +venv/ +ENV/ +env/ +env.bak/ +venv.bak/ + +# Gradio +.gradio/* + +# Weights and Biases +weights/* +dataset/* +cache/* +wandb/* + +# IDE and editor files +.idea/ +.vscode/ +*.swp +*.swo + +# Jupyter Notebook checkpoints +.ipynb_checkpoints/ + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Pytest +.cache/ +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# pyenv +.python-version + +# pipenv +Pipfile.lock + +# pyre type checker +.pyre/ + +# C extensions +*.so + +# Backup files +*~ +*.bak +*.tmp + +#Logging +*.log +logs/ + +*.csv \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..020166acdbf578bf3be5df3b0beaae8294bea368 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,29 @@ +# Use CUDA-enabled PyTorch base image +FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime + +# Set working directory +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + git \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements file +COPY requirements.txt . + +# Install Python dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy project files +COPY . . 
+ +# Create directories for data and models +RUN mkdir -p dataset/final_balanced weights + +# Set environment variables +ENV PYTHONPATH=/app +ENV WANDB_API_KEY="" + +# Default command to run training +CMD ["python", "model/train.py"] \ No newline at end of file diff --git a/analysis/analysis.txt b/analysis/analysis.txt new file mode 100644 index 0000000000000000000000000000000000000000..8896b4a65be1dacca3336ccc955d8ed072d374bc --- /dev/null +++ b/analysis/analysis.txt @@ -0,0 +1,264 @@ +(venv) PS V:\Deeptanshu Lal\PROJECTS\Toxic Comment Classification> python .\analysis\analyze_lang_distribution.py +Reading dataset... + +Dataset Overview: +-------------------------------------------------- +Total number of comments: 361,228 +Number of languages: 7 + +Language Distribution: +-------------------------------------------------- +ru: 52,632 comments (14.57%) +tr: 52,558 comments (14.55%) +pt: 52,440 comments (14.52%) +es: 52,412 comments (14.51%) +fr: 52,368 comments (14.50%) +it: 52,340 comments (14.49%) +en: 46,478 comments (12.87%) + +Class Distribution by Language: +-------------------------------------------------- + +RU (Total: 52,632 comments) +0 toxic classes: 26,316 (50.00%) +1 toxic classes: 7,688 (14.61%) +2 toxic classes: 8,010 (15.22%) +3 toxic classes: 7,103 (13.50%) +4 toxic classes: 2,740 (5.21%) +5 toxic classes: 706 (1.34%) +6 toxic classes: 69 (0.13%) + +TR (Total: 52,558 comments) +0 toxic classes: 26,279 (50.00%) +1 toxic classes: 7,677 (14.61%) +2 toxic classes: 8,004 (15.23%) +3 toxic classes: 7,088 (13.49%) +4 toxic classes: 2,736 (5.21%) +5 toxic classes: 705 (1.34%) +6 toxic classes: 69 (0.13%) + +PT (Total: 52,440 comments) +0 toxic classes: 26,220 (50.00%) +1 toxic classes: 7,668 (14.62%) +2 toxic classes: 7,977 (15.21%) +3 toxic classes: 7,071 (13.48%) +4 toxic classes: 2,732 (5.21%) +5 toxic classes: 703 (1.34%) +6 toxic classes: 69 (0.13%) + +ES (Total: 52,412 comments) +0 toxic classes: 26,206 (50.00%) +1 toxic classes: 7,647 (14.59%) +2 toxic classes: 7,982 (15.23%) +3 toxic classes: 7,069 (13.49%) +4 toxic classes: 2,737 (5.22%) +5 toxic classes: 702 (1.34%) +6 toxic classes: 69 (0.13%) + +FR (Total: 52,368 comments) +0 toxic classes: 26,184 (50.00%) +1 toxic classes: 7,626 (14.56%) +2 toxic classes: 7,990 (15.26%) +3 toxic classes: 7,066 (13.49%) +4 toxic classes: 2,728 (5.21%) +5 toxic classes: 705 (1.35%) +6 toxic classes: 69 (0.13%) + +IT (Total: 52,340 comments) +0 toxic classes: 26,170 (50.00%) +1 toxic classes: 7,652 (14.62%) +2 toxic classes: 7,967 (15.22%) +3 toxic classes: 7,057 (13.48%) +4 toxic classes: 2,722 (5.20%) +5 toxic classes: 703 (1.34%) +6 toxic classes: 69 (0.13%) + +EN (Total: 46,478 comments) +0 toxic classes: 22,989 (49.46%) +1 toxic classes: 8,499 (18.29%) +2 toxic classes: 5,604 (12.06%) +3 toxic classes: 6,391 (13.75%) +4 toxic classes: 2,395 (5.15%) +5 toxic classes: 553 (1.19%) +6 toxic classes: 47 (0.10%) + +Detailed Toxicity Analysis by Language: +-------------------------------------------------- + +RU (Total: 52,632 comments) +- Toxic: + Count: 25,954 (49.31%) + 95% CI: [48.89%, 49.74%] +- Severe Toxic: + Count: 2,441 (4.64%) + 95% CI: [4.46%, 4.82%] +- Obscene: + Count: 12,432 (23.62%) + 95% CI: [23.26%, 23.98%] +- Threat: + Count: 1,075 (2.04%) + 95% CI: [1.92%, 2.16%] +- Insult: + Count: 15,207 (28.89%) + 95% CI: [28.51%, 29.28%] +- Identity Hate: + Count: 2,812 (5.34%) + 95% CI: [5.15%, 5.53%] + +TR (Total: 52,558 comments) +- Toxic: + Count: 25,908 (49.29%) + 95% CI: [48.87%, 49.72%] +- Severe Toxic: + Count: 2,439 
(4.64%) + 95% CI: [4.46%, 4.82%] +- Obscene: + Count: 12,411 (23.61%) + 95% CI: [23.25%, 23.98%] +- Threat: + Count: 1,077 (2.05%) + 95% CI: [1.93%, 2.17%] +- Insult: + Count: 15,170 (28.86%) + 95% CI: [28.48%, 29.25%] +- Identity Hate: + Count: 2,827 (5.38%) + 95% CI: [5.19%, 5.57%] + +PT (Total: 52,440 comments) +- Toxic: + Count: 25,841 (49.28%) + 95% CI: [48.85%, 49.71%] +- Severe Toxic: + Count: 2,432 (4.64%) + 95% CI: [4.46%, 4.82%] +- Obscene: + Count: 12,395 (23.64%) + 95% CI: [23.27%, 24.00%] +- Threat: + Count: 1,080 (2.06%) + 95% CI: [1.94%, 2.18%] +- Insult: + Count: 15,143 (28.88%) + 95% CI: [28.49%, 29.26%] +- Identity Hate: + Count: 2,801 (5.34%) + 95% CI: [5.15%, 5.53%] + +ES (Total: 52,412 comments) +- Toxic: + Count: 25,874 (49.37%) + 95% CI: [48.94%, 49.79%] +- Severe Toxic: + Count: 2,432 (4.64%) + 95% CI: [4.46%, 4.82%] +- Obscene: + Count: 12,388 (23.64%) + 95% CI: [23.27%, 24.00%] +- Threat: + Count: 1,073 (2.05%) + 95% CI: [1.93%, 2.17%] +- Insult: + Count: 15,140 (28.89%) + 95% CI: [28.50%, 29.27%] +- Identity Hate: + Count: 2,783 (5.31%) + 95% CI: [5.12%, 5.50%] + +FR (Total: 52,368 comments) +- Toxic: + Count: 25,877 (49.41%) + 95% CI: [48.99%, 49.84%] +- Severe Toxic: + Count: 2,428 (4.64%) + 95% CI: [4.46%, 4.82%] +- Obscene: + Count: 12,379 (23.64%) + 95% CI: [23.27%, 24.00%] +- Threat: + Count: 1,066 (2.04%) + 95% CI: [1.91%, 2.16%] +- Insult: + Count: 15,131 (28.89%) + 95% CI: [28.51%, 29.28%] +- Identity Hate: + Count: 2,774 (5.30%) + 95% CI: [5.11%, 5.49%] + +IT (Total: 52,340 comments) +- Toxic: + Count: 25,827 (49.34%) + 95% CI: [48.92%, 49.77%] +- Severe Toxic: + Count: 2,429 (4.64%) + 95% CI: [4.46%, 4.82%] +- Obscene: + Count: 12,341 (23.58%) + 95% CI: [23.21%, 23.94%] +- Threat: + Count: 1,077 (2.06%) + 95% CI: [1.94%, 2.18%] +- Insult: + Count: 15,118 (28.88%) + 95% CI: [28.50%, 29.27%] +- Identity Hate: + Count: 2,782 (5.32%) + 95% CI: [5.12%, 5.51%] + +EN (Total: 46,478 comments) +- Toxic: + Count: 22,343 (48.07%) + 95% CI: [47.62%, 48.53%] +- Severe Toxic: + Count: 1,986 (4.27%) + 95% CI: [4.09%, 4.46%] +- Obscene: + Count: 12,356 (26.58%) + 95% CI: [26.18%, 26.99%] +- Threat: + Count: 1,204 (2.59%) + 95% CI: [2.45%, 2.73%] +- Insult: + Count: 11,475 (24.69%) + 95% CI: [24.30%, 25.08%] +- Identity Hate: + Count: 2,143 (4.61%) + 95% CI: [4.42%, 4.80%] + +Statistical Analysis: +-------------------------------------------------- + +Chi-square test for number of toxic classes by language: +Chi-square statistic: 654.28 +p-value: 0.0000000000 +Significant at α=0.05: Yes + +Chi-square test for Toxic: +Chi-square statistic: 26.10 +p-value: 0.0002136602 +Significant at α=0.05: Yes + +Chi-square test for Severe Toxic: +Chi-square statistic: 12.38 +p-value: 0.0540052211 +Significant at α=0.05: No + +Chi-square test for Obscene: +Chi-square statistic: 195.12 +p-value: 0.0000000000 +Significant at α=0.05: Yes + +Chi-square test for Threat: +Chi-square statistic: 57.45 +p-value: 0.0000000001 +Significant at α=0.05: Yes + +Chi-square test for Insult: +Chi-square statistic: 350.72 +p-value: 0.0000000000 +Significant at α=0.05: Yes + +Chi-square test for Identity Hate: +Chi-square statistic: 42.77 +p-value: 0.0000001295 +Significant at α=0.05: Yes \ No newline at end of file diff --git a/analysis/analyze_lang_distribution.py b/analysis/analyze_lang_distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..ac1f042c6bdde79c93d1bfc25ec059c3a72b71d1 --- /dev/null +++ b/analysis/analyze_lang_distribution.py @@ -0,0 +1,336 @@ +import pandas 
as pd +import matplotlib.pyplot as plt +import seaborn as sns +import numpy as np +from scipy import stats +import os + +def set_style(): + """Set the style for all plots""" + # Use a basic style instead of seaborn + plt.style.use('default') + + # Custom style settings + plt.rcParams['figure.figsize'] = (12, 6) + plt.rcParams['font.size'] = 10 + plt.rcParams['axes.titlesize'] = 14 + plt.rcParams['axes.labelsize'] = 12 + plt.rcParams['axes.grid'] = True + plt.rcParams['grid.alpha'] = 0.3 + + # Custom color palette + colors = ['#FF9999', '#66B2FF', '#99FF99', '#FFCC99', '#FF99CC', '#99FFCC', '#FFB366'] + return colors + +def create_language_distribution_plot(df, lang_dist, lang_percent, colors, image_dir): + """Create and save language distribution plot""" + plt.figure(figsize=(14, 8)) + + # Create bar positions + x = np.arange(len(lang_dist)) + + # Create bars with language names as x-ticks + bars = plt.bar(x, lang_dist.values, color=colors) + plt.title('Language Distribution in Multilingual Toxic Comment Dataset', pad=20) + plt.xlabel('Language', labelpad=10) + plt.ylabel('Number of Comments', labelpad=10) + + # Set x-ticks to language names + plt.xticks(x, lang_dist.index, rotation=45) + + # Add value labels on top of each bar with increased spacing + for i, bar in enumerate(bars): + height = bar.get_height() + plt.text(bar.get_x() + bar.get_width()/2., height + (max(lang_dist.values) * 0.01), + f'{int(height):,}\n({lang_percent.values[i]:.1f}%)', + ha='center', va='bottom', fontsize=10) + + # Add some padding to the top of the plot + plt.margins(y=0.2) + + plt.tight_layout() + plt.savefig(os.path.join(image_dir, 'language_distribution.png'), dpi=300, bbox_inches='tight') + plt.close() + +def create_toxicity_heatmap(df, toxicity_cols, image_dir): + """Create and save toxicity correlation heatmap""" + plt.figure(figsize=(12, 10)) + + # Calculate correlation and sort + correlation = df[toxicity_cols].corr() + + # Sort correlation matrix by mean correlation value + mean_corr = correlation.mean() + sorted_cols = mean_corr.sort_values(ascending=False).index + correlation = correlation.loc[sorted_cols, sorted_cols] + + # Create heatmap with better styling + im = plt.imshow(correlation, cmap='RdYlBu_r', aspect='equal', vmin=0, vmax=1) + plt.colorbar(im, label='Correlation Coefficient') + + # Add text annotations with conditional formatting + for i in range(len(correlation)): + for j in range(len(correlation)): + corr_value = correlation.iloc[i, j] + # Choose text color based on background + text_color = 'white' if abs(corr_value) > 0.7 else 'black' + # Make diagonal elements bold + fontweight = 'bold' if i == j else 'normal' + plt.text(j, i, f'{corr_value:.2f}', + ha='center', va='center', + color=text_color, + fontweight=fontweight, + fontsize=10) + + # Improve title and labels + plt.title('Correlation between Different Types of Toxicity\n(Sorted by Average Correlation)', + pad=20, fontsize=14) + + # Format axis labels + formatted_labels = [col.replace('_', ' ').title() for col in correlation.columns] + plt.xticks(range(len(formatted_labels)), formatted_labels, rotation=45, ha='right') + plt.yticks(range(len(formatted_labels)), formatted_labels) + + # Add gridlines + plt.grid(False) + + # Adjust layout + plt.tight_layout() + plt.savefig(os.path.join(image_dir, 'toxicity_correlation.png'), dpi=300, bbox_inches='tight') + plt.close() + +def create_toxicity_by_language_plot(df, lang_dist, toxicity_cols, colors, image_dir): + """Create and save toxicity distribution by language plot""" + 
plt.figure(figsize=(15, 8)) + + x = np.arange(len(lang_dist.index)) + width = 0.15 + multiplier = 0 + + for attribute, color in zip(toxicity_cols, colors): + # Calculate percentage of toxic comments (any value > 0) + attribute_means = [(df[df['lang'] == lang][attribute] > 0).mean() * 100 + for lang in lang_dist.index] + + offset = width * multiplier + rects = plt.bar(x + offset, attribute_means, width, + label=attribute.replace('_', ' ').title(), + color=color, alpha=0.8) + + # Add value labels on the bars + for rect in rects: + height = rect.get_height() + plt.text(rect.get_x() + rect.get_width()/2., height, + f'{height:.1f}%', ha='center', va='bottom', fontsize=8) + + multiplier += 1 + + plt.xlabel('Language') + plt.ylabel('Percentage of Toxic Comments (%)') + plt.title('Distribution of Toxicity Types by Language') + plt.xticks(x + width * 2.5, lang_dist.index, rotation=45) + plt.legend(loc='upper right', bbox_to_anchor=(1, 1)) + plt.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(os.path.join(image_dir, 'toxicity_by_language.png'), dpi=300, bbox_inches='tight') + plt.close() + +def create_class_distribution_plot(df, lang_dist, image_dir): + """Create and save class distribution across languages plot""" + plt.figure(figsize=(16, 10)) + + # Define toxicity columns and their display names + toxicity_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + display_names = [col.replace('_', ' ').title() for col in toxicity_cols] + + # Calculate class distribution for each language + class_dist = {} + non_toxic_dist = {} # Store non-toxic percentages + for lang in lang_dist.index: + lang_df = df[df['lang'] == lang] + total = len(lang_df) + + # Create a binary matrix of toxicity flags + toxic_matrix = lang_df[toxicity_cols].astype(bool) + + # Calculate non-toxic percentage (comments with no toxic flags) + non_toxic_mask = ~toxic_matrix.any(axis=1) + non_toxic_percent = (non_toxic_mask.sum() / total) * 100 + non_toxic_dist[lang] = non_toxic_percent + + # Calculate percentages for each toxicity type + class_dist[lang] = [(toxic_matrix[col].sum() / total) * 100 for col in toxicity_cols] + + # Create stacked bar chart + x = np.arange(len(lang_dist.index)) + + # Use a color scheme with an additional color for non-toxic + colors = plt.cm.Set3(np.linspace(0, 1, len(toxicity_cols) + 1)) + + # First, plot non-toxic comments + non_toxic_values = [non_toxic_dist[lang] for lang in lang_dist.index] + non_toxic_bar = plt.bar(x, non_toxic_values, label='Non-Toxic', color=colors[0], alpha=0.9) + + # Add percentage labels for non-toxic + for j, v in enumerate(non_toxic_values): + if v > 1: # Show all values above 1% + plt.text(x[j], v/2, f'{v:.1f}%', + ha='center', va='center', + color='black', + fontweight='bold', + fontsize=9) + + # Initialize bottom array with non-toxic values + bottom = np.array(non_toxic_values) + + # Then plot toxic categories + bars = [non_toxic_bar] + for i, (col, display_name) in enumerate(zip(toxicity_cols, display_names)): + values = [class_dist[lang][i] for lang in lang_dist.index] + bar = plt.bar(x, values, bottom=bottom, label=display_name, color=colors[i+1], alpha=0.9) + bars.append(bar) + + # Add percentage labels for all values > 1% + for j, v in enumerate(values): + if v > 1: # Show all values above 1% + center = bottom[j] + v/2 + text_color = 'black' if v > 10 else 'black' + plt.text(x[j], center, f'{v:.1f}%', + ha='center', va='center', + color=text_color, + fontweight='bold', + fontsize=9) + bottom = bottom + np.array(values) # Update 
bottom array correctly + + plt.xlabel('Language', labelpad=10, fontsize=12) + plt.ylabel('Percentage of Comments', labelpad=10, fontsize=12) + plt.title('Distribution of Non-Toxic and Toxic Comments by Language', pad=20, fontsize=14) + plt.xticks(x, lang_dist.index, rotation=45, fontsize=10) + + # Adjust legend + plt.legend(title='Comment Types', + bbox_to_anchor=(1.15, 1), + loc='upper left', + fontsize=10, + title_fontsize=12) + + # Add grid for better readability + plt.grid(True, axis='y', alpha=0.3) + + # Adjust layout to prevent label cutoff + plt.margins(y=0.1) + plt.tight_layout() + plt.savefig(os.path.join(image_dir, 'class_distribution.png'), dpi=300, bbox_inches='tight') + plt.close() + +def analyze_language_distribution(): + """Analyze language distribution and toxicity patterns in the dataset""" + # Create images directory if it doesn't exist + image_dir = 'images' + os.makedirs(image_dir, exist_ok=True) + + # Set style and get color palette + colors = set_style() + + # Read the dataset + print("Reading dataset...") + input_file = 'dataset/split/train.csv' + df = pd.read_csv(input_file) + + # Get language distribution + lang_dist = df['lang'].value_counts() + lang_percent = df['lang'].value_counts(normalize=True) * 100 + + # Print basic statistics + print("\nDataset Overview:") + print("-" * 50) + print("Input file: ", input_file) + print(f"Total number of comments: {len(df):,}") + print(f"Number of languages: {df['lang'].nunique()}") + + print("\nLanguage Distribution:") + print("-" * 50) + for lang, count in lang_dist.items(): + print(f"{lang}: {count:,} comments ({lang_percent[lang]:.2f}%)") + + # Create language distribution plot + create_language_distribution_plot(df, lang_dist, lang_percent, colors, image_dir) + + # Analyze toxicity + toxicity_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + + # Create correlation heatmap + create_toxicity_heatmap(df, toxicity_cols, image_dir) + + # Create toxicity by language plot + create_toxicity_by_language_plot(df, lang_dist, toxicity_cols, colors, image_dir) + + # Create class distribution plot + create_class_distribution_plot(df, lang_dist, image_dir) + + # Print class distribution statistics + print("\nClass Distribution by Language:") + print("-" * 50) + + for lang in lang_dist.index: + lang_df = df[df['lang'] == lang] + total = len(lang_df) + + print(f"\n{lang.upper()} (Total: {total:,} comments)") + + # Count comments by number of toxic classes + toxic_counts = lang_df[toxicity_cols].astype(bool).sum(axis=1) + class_dist = toxic_counts.value_counts().sort_index() + + for n_classes, count in class_dist.items(): + percentage = (count / total) * 100 + print(f"{n_classes} toxic classes: {count:,} ({percentage:.2f}%)") + + # Detailed toxicity analysis by language + print("\nDetailed Toxicity Analysis by Language:") + print("-" * 50) + + for lang in lang_dist.index: + lang_df = df[df['lang'] == lang] + print(f"\n{lang.upper()} (Total: {len(lang_df):,} comments)") + + # Calculate toxicity statistics + for col in toxicity_cols: + toxic_count = (lang_df[col] > 0).sum() + toxic_percent = (toxic_count / len(lang_df)) * 100 + + # Calculate confidence interval + ci = stats.norm.interval(0.95, + loc=toxic_percent/100, + scale=np.sqrt((toxic_percent/100 * (1-toxic_percent/100)) / len(lang_df))) + ci_lower, ci_upper = ci[0] * 100, ci[1] * 100 + + print(f"- {col.replace('_', ' ').title()}:") + print(f" Count: {toxic_count:,} ({toxic_percent:.2f}%)") + print(f" 95% CI: [{ci_lower:.2f}%, {ci_upper:.2f}%]") + + # 
Statistical tests + print("\nStatistical Analysis:") + print("-" * 50) + + # Chi-square test for independence between language and number of toxic classes + toxic_class_counts = pd.crosstab(df['lang'], df[toxicity_cols].astype(bool).sum(axis=1)) + chi2, p_value, _, _ = stats.chi2_contingency(toxic_class_counts) + print("\nChi-square test for number of toxic classes by language:") + print(f"Chi-square statistic: {chi2:.2f}") + print(f"p-value: {p_value:.10f}") + print(f"Significant at α=0.05: {'Yes' if p_value < 0.05 else 'No'}") + + # Chi-square test for each toxicity type + for col in toxicity_cols: + binary_col = (df[col] > 0).astype(int) + contingency_table = pd.crosstab(df['lang'], binary_col) + chi2, p_value, _, _ = stats.chi2_contingency(contingency_table) + print(f"\nChi-square test for {col.replace('_', ' ').title()}:") + print(f"Chi-square statistic: {chi2:.2f}") + print(f"p-value: {p_value:.10f}") + print(f"Significant at α=0.05: {'Yes' if p_value < 0.05 else 'No'}") + +if __name__ == "__main__": + analyze_language_distribution() \ No newline at end of file diff --git a/analysis/compute_class_weights.py b/analysis/compute_class_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..2a3d0b64573d4c5e0c73ac22b4717584b00ea98c --- /dev/null +++ b/analysis/compute_class_weights.py @@ -0,0 +1,499 @@ +import numpy as np +import pandas as pd +import json +from typing import Dict, List +import logging + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) + +def validate_parameters(params: Dict) -> Dict: + """ + Validate weight calculation parameters to prevent dangerous combinations. + Includes validation for focal loss parameters. + """ + # Check for dangerous weight scaling + if params['boost_factor'] * params['max_weight'] > 30: + raise ValueError(f"Dangerous weight scaling detected: boost_factor * max_weight = {params['boost_factor'] * params['max_weight']}") + + # Validate focal loss parameters + if not 0 < params['gamma'] <= 5.0: + raise ValueError(f"Invalid gamma value: {params['gamma']}. Must be in (0, 5.0]") + + if not 0 < params['alpha'] < 1: + raise ValueError(f"Invalid alpha value: {params['alpha']}. Must be in (0, 1)") + + # Check for potentially unstable combinations + if params['gamma'] > 3.0 and params['boost_factor'] > 1.5: + logging.warning(f"Potentially unstable combination: high gamma ({params['gamma']}) with high boost factor ({params['boost_factor']})") + + if params['alpha'] > 0.4 and params['boost_factor'] > 1.5: + logging.warning(f"Potentially unstable combination: high alpha ({params['alpha']}) with high boost factor ({params['boost_factor']})") + + return params + +def calculate_safe_weights( + support_0: int, + support_1: int, + max_weight: float = 15.0, + min_weight: float = 0.5, + gamma: float = 2.0, + alpha: float = 0.25, + boost_factor: float = 1.0, + num_classes: int = 6, + lang: str = None, + toxicity_type: str = None +) -> Dict[str, float]: + """ + Calculate class weights with focal loss and adaptive scaling. + Uses focal loss components for better handling of imbalanced classes + while preserving language-specific adjustments. 
+ + Args: + support_0: Number of negative samples + support_1: Number of positive samples + max_weight: Maximum allowed weight + min_weight: Minimum allowed weight + gamma: Focal loss gamma parameter for down-weighting easy examples + alpha: Focal loss alpha parameter for balancing positive/negative classes + boost_factor: Optional boost for specific classes + num_classes: Number of toxicity classes (default=6) + lang: Language code for language-specific constraints + toxicity_type: Type of toxicity for class-specific constraints + """ + # Input validation with detailed error messages + if support_0 < 0 or support_1 < 0: + raise ValueError(f"Negative sample counts: support_0={support_0}, support_1={support_1}") + + eps = 1e-7 # Small epsilon for numerical stability + total = support_0 + support_1 + eps + + # Handle empty dataset case + if total <= eps: + logging.warning(f"Empty dataset for {toxicity_type} in {lang}") + return { + "0": 1.0, + "1": 1.0, + "support_0": support_0, + "support_1": support_1, + "raw_weight_1": 1.0, + "calculation_metadata": { + "formula": "default_weights_empty_dataset", + "constraints_applied": ["empty_dataset_fallback"] + } + } + + # Handle zero support cases safely + if support_1 == 0: + logging.warning(f"No positive samples for {toxicity_type} in {lang}") + return { + "0": 1.0, + "1": max_weight, + "support_0": support_0, + "support_1": support_1, + "raw_weight_1": max_weight, + "calculation_metadata": { + "formula": "max_weight_no_positives", + "constraints_applied": ["no_positives_fallback"] + } + } + + # Determine effective maximum weight based on class and language + if lang == 'en' and toxicity_type == 'threat': + effective_max = min(max_weight, 15.0) # Absolute cap for EN threat + elif toxicity_type == 'identity_hate': + effective_max = min(max_weight, 10.0) # Cap for identity hate + else: + effective_max = max_weight + + try: + # Calculate class frequencies + freq_1 = support_1 / total + freq_0 = support_0 / total + + # Focal loss components + pt = freq_1 + eps # Probability of target class + modulating_factor = (1 - pt) ** gamma + balanced_alpha = alpha / (alpha + (1 - alpha) * (1 - pt)) + + # Base weight calculation with focal loss + raw_weight_1 = balanced_alpha * modulating_factor / (pt + eps) + + # Apply adaptive scaling for severe classes + if toxicity_type in ['threat', 'identity_hate']: + severity_factor = (1 + np.log1p(total) / np.log1p(support_1)) / 2 + raw_weight_1 *= severity_factor + + # Apply boost factor + raw_weight_1 *= boost_factor + + # Detect potential numerical instability + if not np.isfinite(raw_weight_1): + logging.error(f"Numerical instability detected for {toxicity_type} in {lang}") + raw_weight_1 = effective_max + + except Exception as e: + logging.error(f"Weight calculation error: {str(e)}") + raw_weight_1 = effective_max + + # Apply safety limits with effective maximum + weight_1 = min(effective_max, max(min_weight, raw_weight_1)) + weight_0 = 1.0 # Reference weight for majority class + + # Round weights for consistency and to prevent floating point issues + weight_1 = round(float(weight_1), 3) + weight_0 = round(float(weight_0), 3) + + return { + "0": weight_0, + "1": weight_1, + "support_0": support_0, + "support_1": support_1, + "raw_weight_1": round(float(raw_weight_1), 3), + "calculation_metadata": { + "formula": "focal_loss_with_adaptive_scaling", + "gamma": round(float(gamma), 3), + "alpha": round(float(alpha), 3), + "final_pt": round(float(pt), 4), + "effective_max": round(float(effective_max), 3), + 
"modulating_factor": round(float(modulating_factor), 4), + "balanced_alpha": round(float(balanced_alpha), 4), + "severity_adjusted": toxicity_type in ['threat', 'identity_hate'], + "boost_factor": round(float(boost_factor), 3), + "constraints_applied": [ + f"max_weight={effective_max}", + f"boost={boost_factor}", + f"numerical_stability=enforced", + f"adaptive_scaling={'enabled' if toxicity_type in ['threat', 'identity_hate'] else 'disabled'}" + ] + } + } + +def get_language_specific_params(lang: str, toxicity_type: str) -> Dict: + """ + Get language and class specific parameters for weight calculation. + Includes focal loss parameters and their adjustments per language/class. + """ + # Default parameters + default_params = { + "max_weight": 15.0, + "min_weight": 0.5, + "boost_factor": 1.0, + "gamma": 2.0, # Default focal loss gamma + "alpha": 0.25 # Default focal loss alpha + } + + # Updated language-specific adjustments based on analysis + lang_adjustments = { + "en": { + "toxic": { + "boost_factor": 1.67, # To achieve ~3.5x weight + "gamma": 2.5 # More focus on hard examples for main class + }, + "threat": { + "max_weight": 15.0, # Absolute maximum cap + "gamma": 3.0, # Higher gamma for severe class + "alpha": 0.3 # Slightly higher alpha for better recall + }, + "identity_hate": { + "max_weight": 5.0, # Reduced from 8.4 + "gamma": 3.0, # Higher gamma for severe class + "alpha": 0.3 # Slightly higher alpha for better recall + }, + "severe_toxic": { + "max_weight": 3.9, # Corrected weight + "gamma": 2.5 # Moderate gamma for balance + } + }, + "tr": { + "threat": { + "max_weight": 12.8, # Aligned with cross-lingual ratio + "gamma": 2.8 # Slightly lower than EN for stability + }, + "identity_hate": { + "max_weight": 6.2, # Adjusted for balance + "gamma": 2.8 # Slightly lower than EN for stability + } + }, + "ru": { + "threat": { + "max_weight": 12.8, # Aligned with cross-lingual ratio + "gamma": 2.8 # Slightly lower than EN for stability + }, + "identity_hate": { + "max_weight": 7.0, # Adjusted for balance + "gamma": 2.8 # Slightly lower than EN for stability + } + }, + "fr": { + "toxic": { + "boost_factor": 1.2, # To achieve ~2.2x weight + "gamma": 2.2 # Lower gamma for better stability + } + } + } + + # Get language-specific params and validate + lang_params = lang_adjustments.get(lang, {}) + class_params = lang_params.get(toxicity_type, {}) + merged_params = {**default_params, **class_params} + + return validate_parameters(merged_params) + +def check_cross_language_consistency(lang_weights: Dict) -> List[str]: + """ + Check for consistency of weights across languages. + Returns a list of warnings for significant disparities. + """ + warnings = [] + baseline = lang_weights['en'] + + for lang in lang_weights: + if lang == 'en': + continue + + for cls in ['threat', 'identity_hate']: + if cls in lang_weights[lang] and cls in baseline: + ratio = lang_weights[lang][cls]['1'] / baseline[cls]['1'] + if ratio > 1.5 or ratio < 0.67: + warning = f"Large {cls} weight disparity: {lang} vs en ({ratio:.2f}x)" + warnings.append(warning) + logging.warning(warning) + + return warnings + +def validate_dataset_balance(df: pd.DataFrame) -> bool: + """ + Validate dataset balance across languages. + Returns False if imbalance exceeds threshold. 
+ """ + sample_counts = df.groupby('lang').size() + cv = sample_counts.std() / sample_counts.mean() + + if cv > 0.15: # 15% threshold for coefficient of variation + logging.error(f"Dataset language imbalance exceeds 15% (CV={cv:.2%})") + for lang, count in sample_counts.items(): + logging.warning(f"{lang}: {count:,} samples ({count/len(df):.1%})") + return False + return True + +def validate_weights(lang_weights: Dict) -> List[str]: + """ + Ensure weights meet multilingual safety criteria. + Validates weight ratios and focal loss parameters across languages. + + Args: + lang_weights: Dictionary of weights per language and class + + Returns: + List of validation warnings + + Raises: + ValueError: If weights violate safety constraints + """ + warnings = [] + + for lang in lang_weights: + for cls in lang_weights[lang]: + w1 = lang_weights[lang][cls]['1'] + w0 = lang_weights[lang][cls]['0'] + + # Check weight ratio sanity + ratio = w1 / w0 + if ratio > 30: + raise ValueError( + f"Dangerous weight ratio {ratio:.1f}x for {lang} {cls}. " + f"Weight_1={w1:.3f}, Weight_0={w0:.3f}" + ) + elif ratio > 20: + warnings.append( + f"High weight ratio {ratio:.1f}x for {lang} {cls}" + ) + + # Check focal parameter boundaries + metadata = lang_weights[lang][cls]['calculation_metadata'] + gamma = metadata.get('gamma', 0.0) + alpha = metadata.get('alpha', 0.0) + + if gamma > 5.0: + raise ValueError( + f"Unsafe gamma={gamma:.1f} for {lang} {cls}. " + f"Must be <= 5.0" + ) + elif gamma > 4.0: + warnings.append( + f"High gamma={gamma:.1f} for {lang} {cls}" + ) + + if alpha > 0.9: + raise ValueError( + f"Unsafe alpha={alpha:.2f} for {lang} {cls}. " + f"Must be < 0.9" + ) + elif alpha > 0.7: + warnings.append( + f"High alpha={alpha:.2f} for {lang} {cls}" + ) + + # Check for combined risk factors + if gamma > 3.0 and ratio > 15: + warnings.append( + f"Risky combination for {lang} {cls}: " + f"gamma={gamma:.1f}, ratio={ratio:.1f}x" + ) + + return warnings + +def compute_language_weights(df: pd.DataFrame) -> Dict: + """ + Compute weights with inter-language normalization to ensure consistent + weighting across languages while preserving relative class relationships. 
+ """ + # Validate dataset balance first + if not validate_dataset_balance(df): + logging.warning("Proceeding with imbalanced dataset - weights may need manual adjustment") + + lang_weights = {} + toxicity_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + + # First pass: calculate raw weights for each language and class + logging.info("\nFirst pass: Calculating raw weights") + for lang in df['lang'].unique(): + logging.info(f"\nProcessing language: {lang}") + lang_df = df[df['lang'] == lang] + lang_weights[lang] = {} + + for col in toxicity_columns: + y = lang_df[col].values.astype(np.int32) + support_0 = int((y == 0).sum()) + support_1 = int((y == 1).sum()) + + params = get_language_specific_params(lang, col) + weights = calculate_safe_weights( + support_0=support_0, + support_1=support_1, + max_weight=params['max_weight'], + min_weight=params['min_weight'], + gamma=params['gamma'], + alpha=params['alpha'], + boost_factor=params['boost_factor'], + lang=lang, + toxicity_type=col + ) + lang_weights[lang][col] = weights + + # Log initial weights + logging.info(f" {col} - Initial weights:") + logging.info(f" Class 0: {weights['0']:.3f}, samples: {support_0:,}") + logging.info(f" Class 1: {weights['1']:.3f}, samples: {support_1:,}") + + # Second pass: normalize weights across languages + logging.info("\nSecond pass: Normalizing weights across languages") + for col in toxicity_columns: + # Find maximum weight for this toxicity type across all languages + max_weight = max( + lang_weights[lang][col]['1'] + for lang in lang_weights + ) + + if max_weight > 0: # Prevent division by zero + logging.info(f"\nNormalizing {col}:") + logging.info(f" Maximum weight across languages: {max_weight:.3f}") + + # Normalize weights for each language + for lang in lang_weights: + original_weight = lang_weights[lang][col]['1'] + + # Normalize and rescale + normalized_weight = (original_weight / max_weight) * 15.0 + + # Update weight while preserving metadata + lang_weights[lang][col]['raw_weight_1'] = original_weight + lang_weights[lang][col]['1'] = round(normalized_weight, 3) + + # Add normalization info to metadata + lang_weights[lang][col]['calculation_metadata'].update({ + 'normalization': { + 'original_weight': round(float(original_weight), 3), + 'max_weight_across_langs': round(float(max_weight), 3), + 'normalization_factor': round(float(15.0 / max_weight), 3) + } + }) + + # Log normalization results + logging.info(f" {lang}: {original_weight:.3f} → {normalized_weight:.3f}") + + # Validate final weights + logging.info("\nValidating final weights:") + for col in toxicity_columns: + weights_range = [ + lang_weights[lang][col]['1'] + for lang in lang_weights + ] + logging.info(f" {col}: range [{min(weights_range):.3f}, {max(weights_range):.3f}]") + + # Validate weights meet safety criteria + validation_warnings = validate_weights(lang_weights) + if validation_warnings: + logging.warning("\nWeight validation warnings:") + for warning in validation_warnings: + logging.warning(f" {warning}") + + # Check cross-language consistency + consistency_warnings = check_cross_language_consistency(lang_weights) + if consistency_warnings: + logging.warning("\nCross-language consistency warnings:") + for warning in consistency_warnings: + logging.warning(f" {warning}") + + return lang_weights + +def main(): + # Load dataset + input_file = 'dataset/processed/MULTILINGUAL_TOXIC_DATASET_AUGMENTED.csv' + logging.info(f"Loading dataset from {input_file}") + df = pd.read_csv(input_file) + + # 
Compute weights + lang_weights = compute_language_weights(df) + + # Add metadata + weights_data = { + "metadata": { + "total_samples": len(df), + "language_distribution": df['lang'].value_counts().to_dict(), + "weight_calculation": { + "method": "focal_loss_with_adaptive_scaling", + "parameters": { + "default_max_weight": 15.0, + "default_min_weight": 0.5, + "language_specific_adjustments": True + } + } + }, + "weights": lang_weights + } + + # Save weights + output_file = 'weights/language_class_weights.json' + logging.info(f"\nSaving weights to {output_file}") + with open(output_file, 'w', encoding='utf-8') as f: + json.dump(weights_data, f, indent=2, ensure_ascii=False) + + logging.info("\nWeight calculation complete!") + + # Print summary statistics + logging.info("\nSummary of adjustments made:") + for lang in lang_weights: + for col in ['threat', 'identity_hate']: + if col in lang_weights[lang]: + weight = lang_weights[lang][col]['1'] + raw = lang_weights[lang][col]['raw_weight_1'] + if raw != weight: + logging.info(f"{lang} {col}: Adjusted from {raw:.2f}× to {weight:.2f}×") + +if __name__ == "__main__": + main() diff --git a/analysis/plot_loss_curves.py b/analysis/plot_loss_curves.py new file mode 100644 index 0000000000000000000000000000000000000000..b8d94e759d75a31ce7525981f9a9ec08d304949d --- /dev/null +++ b/analysis/plot_loss_curves.py @@ -0,0 +1,374 @@ +import pandas as pd +import torch +import matplotlib.pyplot as plt +import numpy as np +from datetime import datetime +import logging +from pathlib import Path +from torch.utils.data import DataLoader +import sys +import os +import wandb +from transformers import get_linear_schedule_with_warmup + +# Add project root to path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from model.training_config import TrainingConfig +from model.language_aware_transformer import LanguageAwareTransformer +from model.train import ToxicDataset +from transformers import XLMRobertaTokenizer + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +def setup_plot_style(): + """Configure plot styling""" + plt.style.use('seaborn-darkgrid') + plt.rcParams['figure.figsize'] = (12, 12) + plt.rcParams['font.size'] = 12 + +def setup_wandb(): + """Initialize wandb for validation tracking""" + try: + wandb.init( + project="toxic-comment-classification", + name=f"validation-analysis-{datetime.now().strftime('%Y%m%d-%H%M%S')}", + config={ + "analysis_type": "validation_loss", + "timestamp": datetime.now().strftime('%Y%m%d-%H%M%S') + } + ) + logger.info("Initialized wandb logging") + except Exception as e: + logger.error(f"Error initializing wandb: {str(e)}") + raise + +def load_model_and_data(): + """Load the model and validation data""" + try: + # Initialize config with training settings + config = TrainingConfig( + batch_size=16, + num_workers=16, + lr=2e-5, + weight_decay=0.01, + max_grad_norm=1.0, + warmup_ratio=0.1, + label_smoothing=0.01, + + mixed_precision="fp16", + activation_checkpointing=True, + epochs=1 # Number of validation epochs + + ) + + # Load validation data + logger.info("Loading validation and test data...") + val_df = pd.read_csv("dataset/split/val.csv") + test_df = pd.read_csv("dataset/split/test.csv") + combined_df = pd.concat([val_df, test_df]) + tokenizer = XLMRobertaTokenizer.from_pretrained(config.model_name) + combined_dataset = ToxicDataset(combined_df, tokenizer, config, mode='combined') + + + # 
Create combined dataloader + combined_loader = DataLoader( + combined_dataset, + batch_size=config.batch_size, + shuffle=True, # Enable shuffling + num_workers=config.num_workers, + pin_memory=True, + drop_last=False # Keep all samples + ) + + # Log dataloader config to wandb + if wandb.run is not None: + wandb.config.update({ + 'shuffle': True, + 'drop_last': False, + 'total_validation_steps': len(combined_loader), + 'total_validation_samples': len(combined_dataset) + }) + + + # Load model + logger.info("Loading model...") + model = LanguageAwareTransformer( + num_labels=len(config.toxicity_labels), + model_name=config.model_name + ) + + # Load latest checkpoint + checkpoint_path = Path('weights/toxic_classifier_xlm-roberta-large/pytorch_model.bin') + if checkpoint_path.exists(): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + model.load_state_dict(checkpoint) + logger.info("Loaded model checkpoint") + else: + raise FileNotFoundError("No checkpoint found") + + # Move model to GPU if available + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = model.to(device) + + # Setup optimizer + param_groups = config.get_param_groups(model) + optimizer = torch.optim.AdamW(param_groups) + + # Setup scheduler + total_steps = len(combined_loader) * config.epochs + warmup_steps = int(total_steps * config.warmup_ratio) + + scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=warmup_steps, + num_training_steps=total_steps + ) + + # Initialize gradient scaler for mixed precision + scaler = torch.cuda.amp.GradScaler(enabled=config.mixed_precision == "fp16") + + # Log model configuration to wandb + if wandb.run is not None: + wandb.config.update({ + 'model_name': config.model_name, + 'batch_size': config.batch_size, + 'learning_rate': config.lr, + 'weight_decay': config.weight_decay, + 'max_grad_norm': config.max_grad_norm, + 'warmup_ratio': config.warmup_ratio, + 'label_smoothing': config.label_smoothing, + 'mixed_precision': config.mixed_precision, + 'num_workers': config.num_workers, + 'activation_checkpointing': config.activation_checkpointing, + 'validation_epochs': config.epochs + }) + + return model, combined_loader, device, optimizer, scheduler, scaler, config + + + except Exception as e: + logger.error(f"Error loading model and data: {str(e)}") + raise + +def collect_validation_losses(model, combined_loader, device, optimizer, scheduler, scaler, config): + """Run validation and collect step losses across multiple epochs""" + try: + logger.warning("This is an analysis run on combined val+test data - model will not be saved or updated") + # Ensure we're in eval mode and no gradients are computed + model.eval() + for param in model.parameters(): + param.requires_grad = False + + all_losses = [] + epoch_losses = [] + + for epoch in range(config.epochs): + logger.info(f"\nStarting validation epoch {epoch+1}/{config.epochs}") + total_loss = 0 + num_batches = len(combined_loader) + epoch_start_time = datetime.now() + + with torch.no_grad(): # Extra safety to ensure no gradients + for step, batch in enumerate(combined_loader): + # Move batch to device + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v + for k, v in batch.items()} + + # Forward pass with mixed precision + with torch.cuda.amp.autocast(enabled=config.mixed_precision != "no"): + outputs = model(**batch) + loss = outputs['loss'].item() + + total_loss += loss + + # Calculate running averages + avg_loss = total_loss / (step + 1) + + # Get learning rates + lrs = 
[group['lr'] for group in optimizer.param_groups] + + # Log to wandb + wandb.log({ + 'val/step_loss': loss, + 'val/running_avg_loss': avg_loss, + 'val/progress': (step + 1) / num_batches * 100, + 'val/learning_rate': lrs[0], # Base learning rate + 'val/batch_size': config.batch_size, + 'val/epoch': epoch + 1, + 'val/global_step': epoch * num_batches + step + }) + + # Log progress + if step % 10 == 0: + elapsed_time = datetime.now() - epoch_start_time + steps_per_sec = (step + 1) / elapsed_time.total_seconds() + remaining_steps = num_batches - (step + 1) + eta_seconds = remaining_steps / steps_per_sec if steps_per_sec > 0 else 0 + + logger.info( + f"Epoch [{epoch+1}/{config.epochs}] " + f"Step [{step+1}/{num_batches}] " + f"Loss: {loss:.4f} " + f"Avg Loss: {avg_loss:.4f} " + f"LR: {lrs[0]:.2e} " + f"ETA: {int(eta_seconds)}s" + ) + + # Calculate epoch statistics + epoch_avg_loss = total_loss / num_batches + epoch_losses.append({ + 'epoch': epoch + 1, + 'avg_loss': epoch_avg_loss, + 'elapsed_time': (datetime.now() - epoch_start_time).total_seconds() + }) + + # Log epoch metrics to wandb + wandb.log({ + 'val/epoch_avg_loss': epoch_avg_loss, + 'val/epoch_number': epoch + 1, + 'val/epoch_time': epoch_losses[-1]['elapsed_time'] + }) + + # Clear GPU memory after each epoch + torch.cuda.empty_cache() + + return epoch_losses + + except Exception as e: + logger.error(f"Error collecting validation losses: {str(e)}") + raise + +def plot_validation_losses(epoch_losses): + """Plot validation epoch losses""" + try: + setup_plot_style() + + # Create figure + fig, ax = plt.subplots() + + # Extract data + epochs = [d['epoch'] for d in epoch_losses] + losses = [d['avg_loss'] for d in epoch_losses] + + # Plot epoch losses + ax.plot(epochs, losses, 'go-', label='Epoch Average Loss', linewidth=2, markersize=8) + + # Add trend line + z = np.polyfit(epochs, losses, 1) + p = np.poly1d(z) + ax.plot(epochs, p(epochs), "r--", alpha=0.8, label='Loss Trend') + + # Customize plot + ax.set_title('Validation Epoch Losses') + ax.set_xlabel('Epoch') + ax.set_ylabel('Average Loss') + ax.legend() + ax.grid(True, linestyle='--', alpha=0.7) + + # Adjust layout + plt.tight_layout() + + # Create output directory if it doesn't exist + output_dir = Path('analysis/plots') + output_dir.mkdir(parents=True, exist_ok=True) + + # Save plot + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + output_path = output_dir / f'validation_losses_{timestamp}.png' + plt.savefig(output_path, dpi=300, bbox_inches='tight') + logger.info(f"Plot saved to {output_path}") + + # Log plot to wandb + wandb.log({ + "val/loss_plot": wandb.Image(str(output_path)) + }) + + # Show plot + plt.show() + + except Exception as e: + logger.error(f"Error plotting validation losses: {str(e)}") + raise + +def calculate_loss_statistics(epoch_losses): + """Calculate and print loss statistics""" + try: + losses = [d['avg_loss'] for d in epoch_losses] + + stats = { + 'Mean Loss': np.mean(losses), + 'Std Loss': np.std(losses), + 'Min Loss': np.min(losses), + 'Max Loss': np.max(losses), + 'Best Epoch': epoch_losses[np.argmin(losses)]['epoch'] + } + + # Log statistics to wandb + wandb.log({ + 'val/mean_loss': stats['Mean Loss'], + 'val/std_loss': stats['Std Loss'], + 'val/min_loss': stats['Min Loss'], + 'val/max_loss': stats['Max Loss'], + 'val/best_epoch': stats['Best Epoch'] + }) + + # Print statistics + print("\nValidation Loss Statistics:") + for metric_name, value in stats.items(): + if metric_name == 'Best Epoch': + print(f"{metric_name}: {int(value)}") + else: + 
print(f"{metric_name}: {value:.4f}") + + return stats + + except Exception as e: + logger.error(f"Error calculating statistics: {str(e)}") + raise + +def main(): + try: + # Initialize wandb + setup_wandb() + + # Load model and data + logger.info("Loading model and data...") + model, combined_loader, device, optimizer, scheduler, scaler, config = load_model_and_data() + + + # Collect validation losses + logger.info("Collecting validation losses...") + epoch_losses = collect_validation_losses( + model, combined_loader, device, optimizer, scheduler, scaler, config + ) + + + # Plot losses + logger.info("Plotting validation losses...") + plot_validation_losses(epoch_losses) + + # Calculate and print statistics + logger.info("Calculating statistics...") + calculate_loss_statistics(epoch_losses) + + except Exception as e: + logger.error(f"Error in main: {str(e)}") + raise + finally: + # Clean up + torch.cuda.empty_cache() + # Finish wandb run + wandb.finish() + +if __name__ == "__main__": + try: + main() + except Exception as e: + logger.error(f"Script failed: {str(e)}") + raise \ No newline at end of file diff --git a/analysis/plot_roc_curves.py b/analysis/plot_roc_curves.py new file mode 100644 index 0000000000000000000000000000000000000000..41e59e11b29c1a041b9519e2cb9f9500dbe66c43 --- /dev/null +++ b/analysis/plot_roc_curves.py @@ -0,0 +1,163 @@ +import numpy as np +import matplotlib.pyplot as plt +from sklearn.metrics import roc_curve, auc +import os +import json +from pathlib import Path + +def plot_roc_curves(predictions_path, output_dir=None): + """ + Plot ROC curves from model predictions + + Args: + predictions_path (str): Path to the .npz file containing predictions + output_dir (str, optional): Directory to save plots. If None, will use same directory as predictions + """ + # Load predictions + data = np.load(predictions_path) + predictions = data['predictions'] + labels = data['labels'] + langs = data['langs'] + + # Create output directory + if output_dir is None: + output_dir = os.path.dirname(predictions_path) + plots_dir = os.path.join(output_dir, 'plots') + os.makedirs(plots_dir, exist_ok=True) + + # Define toxicity types + toxicity_types = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + + # Define language mapping + id_to_lang = { + 0: 'English (en)', + 1: 'Russian (ru)', + 2: 'Turkish (tr)', + 3: 'Spanish (es)', + 4: 'French (fr)', + 5: 'Italian (it)', + 6: 'Portuguese (pt)' + } + + # Plot overall ROC curves (one per class) + plt.figure(figsize=(10, 8)) + for i, class_name in enumerate(toxicity_types): + fpr, tpr, _ = roc_curve(labels[:, i], predictions[:, i]) + roc_auc = auc(fpr, tpr) + + plt.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.3f})') + + plt.plot([0, 1], [0, 1], 'k--', label='Random') + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title('ROC Curves - All Classes') + plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left') + plt.grid(True) + plt.tight_layout() + plt.savefig(os.path.join(plots_dir, 'roc_all_classes.png'), dpi=300, bbox_inches='tight') + plt.close() + + # Plot per-class ROC curves with confidence intervals + n_bootstrap = 1000 + n_classes = len(toxicity_types) + + for i, class_name in enumerate(toxicity_types): + plt.figure(figsize=(8, 6)) + + # Calculate main ROC curve + fpr, tpr, _ = roc_curve(labels[:, i], predictions[:, i]) + roc_auc = auc(fpr, tpr) + + # Plot main curve + plt.plot(fpr, tpr, 'b-', label=f'ROC (AUC = {roc_auc:.3f})') + + # Bootstrap for confidence intervals + tprs = [] + 
aucs = [] + mean_fpr = np.linspace(0, 1, 100) + + for _ in range(n_bootstrap): + # Bootstrap sample indices + indices = np.random.randint(0, len(labels), len(labels)) + if len(np.unique(labels[indices, i])) < 2: + continue + + # Calculate ROC curve + fpr, tpr, _ = roc_curve(labels[indices, i], predictions[indices, i]) + + # Interpolate TPR at mean FPR points + interp_tpr = np.interp(mean_fpr, fpr, tpr) + interp_tpr[0] = 0.0 + tprs.append(interp_tpr) + aucs.append(auc(fpr, tpr)) + + # Calculate confidence intervals + tprs = np.array(tprs) + mean_tpr = np.mean(tprs, axis=0) + std_tpr = np.std(tprs, axis=0) + + tprs_upper = np.minimum(mean_tpr + std_tpr, 1) + tprs_lower = np.maximum(mean_tpr - std_tpr, 0) + + # Plot confidence interval + plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2, + label=f'±1 std. dev.') + + # Calculate AUC confidence interval + auc_mean = np.mean(aucs) + auc_std = np.std(aucs) + plt.plot([], [], ' ', label=f'AUC = {auc_mean:.3f} ± {auc_std:.3f}') + + plt.plot([0, 1], [0, 1], 'k--', label='Random') + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title(f'ROC Curve - {class_name}') + plt.legend(loc='lower right') + plt.grid(True) + plt.tight_layout() + plt.savefig(os.path.join(plots_dir, f'roc_{class_name}.png'), dpi=300) + plt.close() + + # Plot per-language ROC curves (for toxic class) + plt.figure(figsize=(10, 8)) + for lang_id, lang_name in id_to_lang.items(): + # Get samples for this language + lang_mask = langs == lang_id + if lang_mask.sum() > 0 and len(np.unique(labels[lang_mask, 0])) > 1: + fpr, tpr, _ = roc_curve(labels[lang_mask, 0], predictions[lang_mask, 0]) + roc_auc = auc(fpr, tpr) + plt.plot(fpr, tpr, label=f'{lang_name} (AUC = {roc_auc:.3f})') + + plt.plot([0, 1], [0, 1], 'k--', label='Random') + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title('ROC Curves by Language - Toxic Class') + plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left') + plt.grid(True) + plt.tight_layout() + plt.savefig(os.path.join(plots_dir, 'roc_by_language.png'), dpi=300, bbox_inches='tight') + plt.close() + + print(f"\nROC curves have been saved to {plots_dir}") + print("\nGenerated plots:") + print("1. roc_all_classes.png - ROC curves for all toxicity classes") + print("2. roc_[class_name].png - Individual ROC curves with confidence intervals for each class") + print("3. 
roc_by_language.png - ROC curves for each language (toxic class)") + +if __name__ == '__main__': + # Use the latest evaluation results + eval_dir = 'evaluation_results' + if os.path.exists(eval_dir): + # Find most recent evaluation directory + eval_dirs = sorted([d for d in os.listdir(eval_dir) if d.startswith('eval_')], reverse=True) + if eval_dirs: + latest_eval = os.path.join(eval_dir, eval_dirs[0]) + predictions_path = os.path.join(latest_eval, 'predictions.npz') + if os.path.exists(predictions_path): + plot_roc_curves(predictions_path) + else: + print(f"No predictions file found in {latest_eval}") + else: + print(f"No evaluation directories found in {eval_dir}") + else: + print(f"Evaluation directory {eval_dir} not found") \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..a513fa3b8921234e69fe68d90243df38dbc609ec --- /dev/null +++ b/app.py @@ -0,0 +1,262 @@ +import gradio as gr +import torch +import numpy as np +import os +import json +from model.inference_optimized import OptimizedToxicityClassifier +import matplotlib.pyplot as plt +from typing import List, Dict +import langid +import pandas as pd + +# Configure paths +ONNX_MODEL_PATH = os.environ.get("ONNX_MODEL_PATH", "weights/toxic_classifier.onnx") +PYTORCH_MODEL_PATH = os.environ.get("PYTORCH_MODEL_PATH", "weights/toxic_classifier_xlm-roberta-large/pytorch_model.bin") +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + +# Supported languages +SUPPORTED_LANGUAGES = { + 'en': 'English', + 'ru': 'Russian', + 'tr': 'Turkish', + 'es': 'Spanish', + 'fr': 'French', + 'it': 'Italian', + 'pt': 'Portuguese' +} + +# Initialize classifier +try: + if os.path.exists(ONNX_MODEL_PATH): + classifier = OptimizedToxicityClassifier(onnx_path=ONNX_MODEL_PATH, device=DEVICE) + print(f"Loaded ONNX model from {ONNX_MODEL_PATH}") + else: + classifier = OptimizedToxicityClassifier(pytorch_path=PYTORCH_MODEL_PATH, device=DEVICE) + print(f"Loaded PyTorch model from {PYTORCH_MODEL_PATH}") +except Exception as e: + print(f"Error loading model: {str(e)}") + classifier = None + +def detect_language(text: str) -> str: + """Detect language of input text""" + try: + lang, _ = langid.classify(text) + return lang if lang in SUPPORTED_LANGUAGES else 'en' + except: + return 'en' + +def predict_toxicity(text: str, selected_language: str = None) -> Dict: + """Predict toxicity of input text""" + if not text or not text.strip(): + return { + "error": "Please enter some text to analyze.", + "html_result": "
<div class='error'>Please enter some text to analyze.</div>"
+        }
+
+    if classifier is None:
+        return {
+            "error": "Model not loaded. Please check logs.",
+            "html_result": "<div class='error'>Model not loaded. Please check logs.</div>"
+        }
+
+    # Detect language if not specified
+    if not selected_language or selected_language == "Auto-detect":
+        lang_code = detect_language(text)
+        detected = True
+    else:
+        # Convert from display name to code
+        lang_code = next((code for code, name in SUPPORTED_LANGUAGES.items()
+                         if name == selected_language), 'en')
+        detected = False
+
+    # Run prediction
+    try:
+        results = classifier.predict([text], langs=[lang_code])[0]
+
+        # Format probabilities for display
+        probs = results["probabilities"]
+        sorted_categories = sorted(
+            [(label, probs[label]) for label in probs],
+            key=lambda x: x[1],
+            reverse=True
+        )
+
+        # Create bar chart
+        fig, ax = plt.subplots(figsize=(10, 6))
+        labels = [label.replace('_', ' ').title() for label, _ in sorted_categories]
+        values = [prob * 100 for _, prob in sorted_categories]
+        colors = ['#ff6b6b' if val >= 50 else '#74c0fc' for val in values]
+
+        ax.barh(labels, values, color=colors)
+        ax.set_xlim(0, 100)
+        ax.set_xlabel('Probability (%)')
+        ax.set_title('Toxicity Analysis')
+        ax.grid(axis='x', linestyle='--', alpha=0.7)
+
+        # Annotate values
+        for i, v in enumerate(values):
+            ax.text(v + 1, i, f'{v:.1f}%', va='center')
+
+        # Create HTML result
+        lang_display = SUPPORTED_LANGUAGES.get(lang_code, lang_code)
+        overall_result = "TOXIC" if results["is_toxic"] else "NON-TOXIC"
+        result_color = "#ff6b6b" if results["is_toxic"] else "#66d9e8"
+
+        html_result = f"""
+        <div style="padding: 12px; border: 2px solid {result_color}; border-radius: 8px;">
+            <h2 style="color: {result_color}; margin-top: 0;">Analysis Result: {overall_result}</h2>
+            <p><b>Language:</b> {lang_display} {'(detected)' if detected else ''}</p>
+            <table style="width: 100%; border-collapse: collapse;">
+                <tr>
+                    <th style="text-align: left; padding: 6px;">Category</th>
+                    <th style="text-align: left; padding: 6px;">Probability</th>
+                    <th style="text-align: left; padding: 6px;">Status</th>
+                </tr>
+        """
+
+        # Add rows for each toxicity category
+        for label, prob in sorted_categories:
+            formatted_label = label.replace('_', ' ').title()
+            status = "DETECTED" if prob >= 0.5 else "Not Detected"
+            status_color = "#ff6b6b" if prob >= 0.5 else "#66d9e8"
+            prob_percent = f"{prob * 100:.1f}%"
+
+            html_result += f"""
+                <tr>
+                    <td style="padding: 6px;">{formatted_label}</td>
+                    <td style="padding: 6px;">{prob_percent}</td>
+                    <td style="padding: 6px; color: {status_color}; font-weight: bold;">{status}</td>
+                </tr>
+            """
+
+        html_result += "</table></div>"
+
+        # Add detected categories if toxic
+        if results["is_toxic"]:
+            toxic_categories = [cat.replace('_', ' ').title() for cat in results["toxic_categories"]]
+            categories_list = ", ".join(toxic_categories)
+            html_result += f"""
+            <div style="margin-top: 10px;">
+                <p style="color: #ff6b6b;"><b>Detected toxic categories:</b> {categories_list}</p>
+            </div>
+            """
+
+        return {
+            "prediction": results,
+            "html_result": html_result,
+            "fig": fig
+        }
+
+    except Exception as e:
+        import traceback
+        traceback.print_exc()
+        return {
+            "error": f"Error processing text: {str(e)}",
+            "html_result": f"<div class='error'>Error processing text: {str(e)}</div>
" + } + +def create_app(): + """Create and configure the Gradio interface""" + # Create language dropdown options + language_options = ["Auto-detect"] + list(SUPPORTED_LANGUAGES.values()) + + # Define the interface + with gr.Blocks(css=""" + .error { color: #ff6b6b; font-weight: bold; padding: 10px; border: 1px solid #ff6b6b; } + .container { margin: 0 auto; max-width: 900px; } + .gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; } + .example-text { font-style: italic; color: #666; } + """) as app: + gr.Markdown(""" + # Multilingual Toxic Comment Classifier + This app analyzes text for different types of toxicity across multiple languages. + Enter your text, select a language (or let it auto-detect), and click 'Analyze'. + + Supported languages: English, Russian, Turkish, Spanish, French, Italian, Portuguese + """) + + with gr.Row(): + with gr.Column(scale=3): + text_input = gr.Textbox( + label="Enter text to analyze", + placeholder="Type or paste text here...", + lines=5 + ) + lang_dropdown = gr.Dropdown( + choices=language_options, + value="Auto-detect", + label="Language" + ) + analyze_btn = gr.Button("Analyze", variant="primary") + + with gr.Column(scale=2): + gr.Markdown("### Example texts:") + with gr.Accordion("English example"): + en_example_btn = gr.Button("Use English example") + with gr.Accordion("Spanish example"): + es_example_btn = gr.Button("Use Spanish example") + with gr.Accordion("French example"): + fr_example_btn = gr.Button("Use French example") + + # Examples + en_example_text = "You are such an idiot, nobody likes your stupid content." + es_example_text = "Eres un completo idiota y nadie te quiere." + fr_example_text = "Tu es tellement stupide, personne n'aime ton contenu minable." + + en_example_btn.click( + lambda: en_example_text, + outputs=text_input + ) + es_example_btn.click( + lambda: es_example_text, + outputs=text_input + ) + fr_example_btn.click( + lambda: fr_example_text, + outputs=text_input + ) + + # Output components + result_html = gr.HTML(label="Analysis Result") + plot_output = gr.Plot(label="Toxicity Probabilities") + + # Set up event handling + analyze_btn.click( + predict_toxicity, + inputs=[text_input, lang_dropdown], + outputs=[result_html, plot_output] + ) + + # Also analyze on pressing Enter in the text box + text_input.submit( + predict_toxicity, + inputs=[text_input, lang_dropdown], + outputs=[result_html, plot_output] + ) + + gr.Markdown(""" + ### About this model + This model classifies text into six toxicity categories: + - **Toxic**: General toxicity + - **Severe Toxic**: Extreme toxicity + - **Obscene**: Obscene content + - **Threat**: Threatening content + - **Insult**: Insulting content + - **Identity Hate**: Identity-based hate + + Built using XLM-RoBERTa with language-aware fine-tuning. 
+ """) + + return app + +# Launch the app when script is run directly +if __name__ == "__main__": + # Create and launch the app + app = create_app() + app.launch( + server_name="0.0.0.0", # Bind to all interfaces + server_port=7860, # Default Gradio port + share=True # Generate public link + ) \ No newline at end of file diff --git a/augmentation/balance_english.py b/augmentation/balance_english.py new file mode 100644 index 0000000000000000000000000000000000000000..b0a0424246395d2e0155f533534184b337495003 --- /dev/null +++ b/augmentation/balance_english.py @@ -0,0 +1,237 @@ +import os +import torch + +# Configure CPU and thread settings FIRST, before any other imports +os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1' +os.environ['TF_CPU_ENABLE_AVX2'] = '1' +os.environ['TF_CPU_ENABLE_AVX512F'] = '1' +os.environ['TF_CPU_ENABLE_AVX512_VNNI'] = '1' +os.environ['TF_CPU_ENABLE_FMA'] = '1' +os.environ['MKL_NUM_THREADS'] = '80' +os.environ['OMP_NUM_THREADS'] = '80' + +# Set PyTorch thread configurations once +torch.set_num_threads(80) +torch.set_num_interop_threads(10) + +# Now import everything else +import pandas as pd +import numpy as np +from pathlib import Path +import logging +from datetime import datetime +import sys +from toxic_augment import ToxicAugmenter +import json + +# Configure logging +log_dir = Path("logs") +log_dir.mkdir(exist_ok=True) + +timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") +log_file = log_dir / f"balance_english_{timestamp}.log" + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s | %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout), + logging.FileHandler(log_file) + ] +) + +logger = logging.getLogger(__name__) + +def analyze_label_distribution(df, lang='en'): + """Analyze label distribution for a specific language""" + labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + lang_df = df[df['lang'] == lang] + total = len(lang_df) + + if total == 0: + logger.warning(f"No samples found for language {lang.upper()}.") + return {} + + logger.info(f"\nLabel Distribution for {lang.upper()}:") + logger.info("-" * 50) + dist = {} + for label in labels: + count = lang_df[label].sum() + percentage = (count / total) * 100 + dist[label] = {'count': int(count), 'percentage': percentage} + logger.info(f"{label}: {count:,} ({percentage:.2f}%)") + return dist + +def analyze_language_distribution(df): + """Analyze current language distribution""" + lang_dist = df['lang'].value_counts() + logger.info("\nCurrent Language Distribution:") + logger.info("-" * 50) + for lang, count in lang_dist.items(): + logger.info(f"{lang}: {count:,} comments ({count/len(df)*100:.2f}%)") + return lang_dist + +def calculate_required_samples(df): + """Calculate how many English samples we need to generate""" + lang_counts = df['lang'].value_counts() + target_count = lang_counts.max() # Use the largest language count as target + en_count = lang_counts.get('en', 0) + required_samples = target_count - en_count + + logger.info(f"\nTarget count per language: {target_count:,}") + logger.info(f"Current English count: {en_count:,}") + logger.info(f"Required additional English samples: {required_samples:,}") + + return required_samples + +def generate_balanced_samples(df, required_samples): + """Generate samples maintaining original class distribution ratios""" + logger.info("\nGenerating balanced samples...") + + # Get English samples + en_df = df[df['lang'] == 'en'] + labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + + # 
Calculate target counts for each label + target_counts = {} + for label in labels: + count = en_df[label].sum() + ratio = count / len(en_df) + target_count = int(ratio * required_samples) + target_counts[label] = target_count + logger.info(f"Target count for {label}: {target_count:,}") + + augmented_samples = [] + augmenter = ToxicAugmenter() + total_generated = 0 + + # Generate samples for each label + for label, target_count in target_counts.items(): + if target_count == 0: + continue + + logger.info(f"\nGenerating {target_count:,} samples for {label}") + + # Get seed texts with this label + seed_texts = en_df[en_df[label] == 1]['comment_text'].tolist() + + if not seed_texts: + logger.warning(f"No seed texts found for {label}, skipping...") + continue + + # Generate samples with 5-minute timeout + new_samples = augmenter.augment_dataset( + target_samples=target_count, + label=label, # Using single label instead of label_combo + seed_texts=seed_texts, + timeout_minutes=5 + ) + + if new_samples is not None and not new_samples.empty: + augmented_samples.append(new_samples) + total_generated += len(new_samples) + + # Log progress + logger.info(f"✓ Generated {len(new_samples):,} samples") + logger.info(f"Progress: {total_generated:,}/{required_samples:,}") + + # Check if we have reached our global required samples + if total_generated >= required_samples: + logger.info("Reached required sample count, stopping generation") + break + + # Combine all generated samples + if augmented_samples: + augmented_df = pd.concat(augmented_samples, ignore_index=True) + augmented_df['lang'] = 'en' + + # Ensure we don't exceed the required sample count + if len(augmented_df) > required_samples: + logger.info(f"Trimming excess samples from {len(augmented_df):,} to {required_samples:,}") + augmented_df = augmented_df.head(required_samples) + + # Log final class distribution + logger.info("\nFinal class distribution in generated samples:") + for label in labels: + count = augmented_df[label].sum() + percentage = (count / len(augmented_df)) * 100 + logger.info(f"{label}: {count:,} ({percentage:.2f}%)") + + # Also log clean samples + clean_count = len(augmented_df[augmented_df[labels].sum(axis=1) == 0]) + clean_percentage = (clean_count / len(augmented_df)) * 100 + logger.info(f"Clean samples: {clean_count:,} ({clean_percentage:.2f}%)") + + return augmented_df + else: + raise Exception("Failed to generate any valid samples") + +def balance_english_data(): + """Main function to balance English data with other languages""" + try: + # Load dataset + input_file = 'dataset/processed/MULTILINGUAL_TOXIC_DATASET_360K_7LANG_FINAL.csv' + logger.info(f"Loading dataset from {input_file}") + df = pd.read_csv(input_file) + + # Analyze current distribution + logger.info("\nAnalyzing current distribution...") + initial_dist = analyze_language_distribution(df) + initial_label_dist = analyze_label_distribution(df, 'en') + + # Calculate required samples + required_samples = calculate_required_samples(df) + + if required_samples <= 0: + logger.info("English data is already balanced. 
No augmentation needed.") + return + + # Generate balanced samples + augmented_df = generate_balanced_samples(df, required_samples) + + # Merge with original dataset + logger.info("\nMerging datasets...") + output_file = 'dataset/processed/MULTILINGUAL_TOXIC_DATASET_BALANCED.csv' + + # Combine datasets + combined_df = pd.concat([df, augmented_df], ignore_index=True) + + # Save balanced dataset + combined_df.to_csv(output_file, index=False) + logger.info(f"\nSaved balanced dataset to {output_file}") + + # Final distribution check + logger.info("\nFinal distribution after balancing:") + final_dist = analyze_language_distribution(combined_df) + final_label_dist = analyze_label_distribution(combined_df, 'en') + + # Save distribution statistics + stats = { + 'timestamp': timestamp, + 'initial_distribution': { + 'languages': initial_dist.to_dict(), + 'english_labels': initial_label_dist + }, + 'final_distribution': { + 'languages': final_dist.to_dict(), + 'english_labels': final_label_dist + }, + 'samples_generated': len(augmented_df), + 'total_samples': len(combined_df) + } + + stats_file = f'logs/balance_stats_{timestamp}.json' + with open(stats_file, 'w') as f: + json.dump(stats, f, indent=2) + logger.info(f"\nSaved balancing statistics to {stats_file}") + + except Exception as e: + logger.error(f"Error during balancing: {str(e)}") + raise + +def main(): + balance_english_data() + +if __name__ == "__main__": + logger.info("Starting English data balancing process...") + main() \ No newline at end of file diff --git a/augmentation/threat_augment.py b/augmentation/threat_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..30b00f2648dc63a371f205132bdbd30ee68d773b --- /dev/null +++ b/augmentation/threat_augment.py @@ -0,0 +1,379 @@ +import torch +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig +) +from langdetect import detect +import pandas as pd +import numpy as np +from tqdm import tqdm +from pathlib import Path +import logging +import gc +from typing import List +import json +from datetime import datetime, timedelta +import time +import sys +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.linear_model import LogisticRegression +import joblib + +# Create log directories +log_dir = Path("logs") +log_dir.mkdir(exist_ok=True) + +# Get timestamp for log file +timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") +log_file = log_dir / f"generation_{timestamp}.log" + +# Configure logging once at the start +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s | %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout), + logging.FileHandler(log_file) + ] +) + +logger = logging.getLogger(__name__) +logger.info(f"Starting new run. 
Log file: {log_file}") + +def log_separator(message: str = ""): + """Print a separator line with optional message""" + if message: + logger.info("\n" + "="*40 + f" {message} " + "="*40) + else: + logger.info("\n" + "="*100) + +class FastThreatValidator: + """Fast threat validation using logistic regression""" + def __init__(self, model_path: str = "weights/threat_validator.joblib"): + self.model_path = model_path + if Path(model_path).exists(): + logger.info("Loading fast threat validator...") + model_data = joblib.load(model_path) + self.vectorizer = model_data['vectorizer'] + self.model = model_data['model'] + logger.info("✓ Fast validator loaded") + else: + logger.info("Training fast threat validator...") + self._train_validator() + logger.info("✓ Fast validator trained and saved") + + def _train_validator(self): + """Train a simple logistic regression model for threat detection""" + # Load training data + train_df = pd.read_csv("dataset/split/train.csv") + + # Prepare data + X = train_df['comment_text'].fillna('') + y = train_df['threat'] + + # Create and fit vectorizer + self.vectorizer = TfidfVectorizer( + max_features=10000, + ngram_range=(1, 2), + strip_accents='unicode', + min_df=2 + ) + X_vec = self.vectorizer.fit_transform(X) + + # Train model + self.model = LogisticRegression( + C=1.0, + class_weight='balanced', + max_iter=200, + n_jobs=-1 + ) + self.model.fit(X_vec, y) + + # Save model + joblib.dump({ + 'vectorizer': self.vectorizer, + 'model': self.model + }, self.model_path) + + def validate(self, texts: List[str], threshold: float = 0.6) -> List[bool]: + """Validate texts using the fast model""" + # Vectorize texts + X = self.vectorizer.transform(texts) + + # Get probabilities + probs = self.model.predict_proba(X)[:, 1] + + # Return boolean mask + return probs >= threshold + +class ThreatAugmenter: + def __init__(self, seed_samples_path: str = "dataset/processed/MULTILINGUAL_TOXIC_DATASET_360K_7LANG_FINAL.csv"): + log_separator("INITIALIZATION") + + # Use global log file + self.log_file = log_file + + # Initialize generation buffer + self.generation_buffer = [] + self.buffer_size = 100 # Flush buffer every 100 entries + + # Multi-GPU setup + self.num_gpus = torch.cuda.device_count() + if self.num_gpus > 0: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + logger.info(f"Found {self.num_gpus} GPUs:") + for i in range(self.num_gpus): + mem = torch.cuda.get_device_properties(i).total_memory / 1024**3 + logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)") + + # Load models + log_separator("LOADING MODELS") + logger.info("Loading Mistral-7B...") + + # Configure model for multi-GPU + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True + ) + + self.llm = AutoModelForCausalLM.from_pretrained( + "mistralai/Mistral-7B-Instruct-v0.3", + device_map="balanced", # Ensures proper dual GPU usage + torch_dtype=torch.float16, + quantization_config=quantization_config, + max_memory={0: "22GB", 1: "22GB"} # Explicitly set memory limits for each GPU + ) + + self.llm_tokenizer = AutoTokenizer.from_pretrained( + "mistralai/Mistral-7B-Instruct-v0.3", + padding_side="left", + use_fast=True + ) + self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token + logger.info("✓ Mistral-7B loaded") + + # Initialize fast validator + self.validator = FastThreatValidator() + logger.info("✓ Fast validator initialized") + + # Load and 
preprocess data
+        log_separator("LOADING DATA")
+        logger.info(f"Loading samples from {seed_samples_path}")
+        self.seed_df = pd.read_csv(seed_samples_path)
+        self.en_threat_samples = self.seed_df[
+            (self.seed_df['lang'] == 'en') &
+            (self.seed_df['threat'] == 1)
+        ]
+        logger.info(f"Found {len(self.en_threat_samples)} EN threat samples")
+
+        # Optimize batch processing
+        self.max_batch_size = 48  # Increased batch size
+        self.prefetch_factor = 4
+        self.num_workers = 8
+
+        # Memory cleanup
+        torch.cuda.empty_cache()
+        gc.collect()
+        log_separator()
+
+    def generate_prompt(self, seed_text: str) -> str:
+        """Generate a prompt for realistic threatening content"""
+        return f"""[INST] Generate a threatening social media comment that is realistic and intimidating.
+
+Reference example (generate something different): {seed_text}
+
+Requirements:
+- Make it genuinely threatening
+- Use authentic social media language
+- Keep it under 50 words
+- Must be different from example
+
+Generate ONLY the comment: [/INST]"""
+
+    def flush_buffer(self):
+        """Flush the generation buffer to disk"""
+        if self.generation_buffer:
+            try:
+                with open(self.log_file, 'a', encoding='utf-8') as f:
+                    for entry in self.generation_buffer:
+                        f.write(json.dumps(entry, ensure_ascii=False) + '\n')
+                self.generation_buffer = []
+            except Exception as e:
+                logger.error(f"Failed to flush buffer: {str(e)}")
+
+    def log_generation(self, seed_text: str, prompt: str, generated_text: str, is_valid: bool):
+        """Buffer log generation details"""
+        log_entry = {
+            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+            "seed_text": seed_text,
+            "prompt": prompt,
+            "generated_text": generated_text,
+            "is_valid": is_valid
+        }
+
+        self.generation_buffer.append(log_entry)
+
+        # Flush buffer if it reaches the size limit
+        if len(self.generation_buffer) >= self.buffer_size:
+            self.flush_buffer()
+
+    def generate_samples(self, prompts: List[str], seed_texts: List[str]) -> List[str]:
+        try:
+            with torch.amp.autocast('cuda', dtype=torch.float16):
+                inputs = self.llm_tokenizer(prompts, return_tensors="pt", padding=True,
+                                            truncation=True, max_length=256).to(self.llm.device)
+
+                outputs = self.llm.generate(
+                    **inputs,
+                    max_new_tokens=32,
+                    temperature=0.95,
+                    do_sample=True,
+                    top_p=0.92,
+                    top_k=50,
+                    num_return_sequences=1,
+                    repetition_penalty=1.15,
+                    pad_token_id=self.llm_tokenizer.pad_token_id,
+                    eos_token_id=self.llm_tokenizer.eos_token_id
+                )
+
+                texts = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=False)
+                cleaned_texts = []
+                valid_count = 0
+
+                # Process responses with minimal logging
+                for idx, text in enumerate(texts):
+                    if "[/INST]" in text and "</s>" in text:
+                        response = text.split("[/INST]")[1].split("</s>")[0].strip()
+                        response = response.strip().strip('"').strip("'")
+
+                        word_count = len(response.split())
+                        if (word_count >= 3 and word_count <= 50 and
+                            not any(x in response.lower() for x in [
+                                "generate", "requirements:", "reference",
+                                "[inst]", "example"
+                            ])):
+                            cleaned_texts.append(response)
+                            valid_count += 1
+
+                # Log only summary statistics
+                if valid_count > 0:
+                    logger.info(f"\nBatch Success: {valid_count}/{len(texts)} ({valid_count/len(texts)*100:.1f}%)")
+
+                return cleaned_texts
+
+        except Exception as e:
+            logger.error(f"Generation error: {str(e)}")
+            return []
+
+    def validate_toxicity(self, texts: List[str]) -> torch.Tensor:
+        """Validate texts using fast logistic regression"""
+        if not texts:
+            return torch.zeros(0, dtype=torch.bool)
+
+        # Get validation mask from fast validator
+        validation_mask =
self.validator.validate(texts) + + # Convert to torch tensor + return torch.tensor(validation_mask, dtype=torch.bool, device=self.llm.device) + + def validate_language(self, texts: List[str]) -> List[bool]: + """Simple language validation""" + return [detect(text) == 'en' for text in texts] + + def augment_dataset(self, target_samples: int = 500, batch_size: int = 32): + """Main augmentation loop with progress bar and CSV saving""" + try: + start_time = time.time() + logger.info(f"Starting generation: target={target_samples}, batch_size={batch_size}") + generated_samples = [] + stats = { + "total_attempts": 0, + "valid_samples": 0, + "batch_times": [] + } + + # Create output directory if it doesn't exist + output_dir = Path("dataset/augmented") + output_dir.mkdir(parents=True, exist_ok=True) + + # Generate timestamp for the filename + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_file = output_dir / f"threat_augmented_{timestamp}.csv" + + # Initialize progress bar + pbar = tqdm(total=target_samples, + desc="Generating samples", + unit="samples", + ncols=100, + bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]') + + while len(generated_samples) < target_samples: + batch_start = time.time() + + seed_texts = self.en_threat_samples['comment_text'].sample(batch_size).tolist() + prompts = [self.generate_prompt(text) for text in seed_texts] + new_samples = self.generate_samples(prompts, seed_texts) + + if not new_samples: + continue + + # Update statistics + batch_time = time.time() - batch_start + stats["batch_times"].append(batch_time) + stats["total_attempts"] += len(new_samples) + prev_len = len(generated_samples) + generated_samples.extend(new_samples) + stats["valid_samples"] = len(generated_samples) + + # Update progress bar + pbar.update(len(generated_samples) - prev_len) + + # Calculate and display success rate periodically + if len(stats["batch_times"]) % 10 == 0: # Every 10 batches + success_rate = (stats["valid_samples"] / stats["total_attempts"]) * 100 + avg_batch_time = sum(stats["batch_times"][-20:]) / min(len(stats["batch_times"]), 20) + pbar.set_postfix({ + 'Success Rate': f'{success_rate:.1f}%', + 'Batch Time': f'{avg_batch_time:.2f}s' + }) + + # Cleanup + if len(generated_samples) % (batch_size * 5) == 0: + torch.cuda.empty_cache() + gc.collect() + + # Close progress bar + pbar.close() + + # Create DataFrame and save to CSV + df = pd.DataFrame({ + 'text': generated_samples[:target_samples], + 'label': 1, # These are all threat samples + 'source': 'augmented', + 'timestamp': timestamp + }) + + # Save to CSV + df.to_csv(output_file, index=False) + logger.info(f"\nSaved {len(df)} samples to {output_file}") + + # Final stats + total_time = str(timedelta(seconds=int(time.time() - start_time))) + logger.info(f"Generation complete: {len(generated_samples)} samples generated in {total_time}") + + return df + + except Exception as e: + logger.error(f"Generation failed: {str(e)}") + raise + +if __name__ == "__main__": + torch.cuda.empty_cache() + gc.collect() + + augmenter = ThreatAugmenter() + augmented_df = augmenter.augment_dataset(target_samples=500) \ No newline at end of file diff --git a/augmentation/toxic_augment.py b/augmentation/toxic_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..ce62dce2ebaaa5abf8c81839c2f77301620f4f8c --- /dev/null +++ b/augmentation/toxic_augment.py @@ -0,0 +1,439 @@ +import torch +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig +) +import pandas as 
pd +import numpy as np +from tqdm import tqdm +from pathlib import Path +import logging +import gc +from typing import List, Dict +import json +from datetime import datetime +import time +import sys +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.linear_model import LogisticRegression +import joblib +import random + +# Create log directories +log_dir = Path("logs") +log_dir.mkdir(exist_ok=True) + +# Get timestamp for log file +timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") +log_file = log_dir / f"generation_{timestamp}.log" + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s | %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout), + logging.FileHandler(log_file) + ] +) + +logger = logging.getLogger(__name__) +logger.info(f"Starting new run. Log file: {log_file}") + +class FastToxicValidator: + """Fast toxicity validation using logistic regression""" + def __init__(self, model_path: str = "weights/toxic_validator.joblib"): + self.model_path = model_path + if Path(model_path).exists(): + logger.info("Loading fast toxic validator...") + model_data = joblib.load(model_path) + self.vectorizers = model_data['vectorizers'] + self.models = model_data['models'] + logger.info("✓ Fast validator loaded") + else: + logger.info("Training fast toxic validator...") + self._train_validator() + logger.info("✓ Fast validator trained and saved") + + def _train_validator(self): + """Train logistic regression models for each toxicity type""" + # Load training data + train_df = pd.read_csv("dataset/split/train.csv") + + # Labels to validate + labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + + self.vectorizers = {} + self.models = {} + + # Train a model for each label + for label in labels: + # Create and fit vectorizer + vectorizer = TfidfVectorizer( + max_features=10000, + ngram_range=(1, 2), + strip_accents='unicode', + min_df=2 + ) + X = vectorizer.fit_transform(train_df['comment_text'].fillna('')) + y = train_df[label] + + # Train model + model = LogisticRegression( + C=1.0, + class_weight='balanced', + max_iter=200, + n_jobs=-1 + ) + model.fit(X, y) + + self.vectorizers[label] = vectorizer + self.models[label] = model + + # Save models + joblib.dump({ + 'vectorizers': self.vectorizers, + 'models': self.models + }, self.model_path) + + def get_probabilities(self, texts: List[str], label: str) -> np.ndarray: + """Get raw probabilities for a specific label""" + X = self.vectorizers[label].transform(texts) + return self.models[label].predict_proba(X)[:, 1] + + def validate(self, texts: List[str], label: str, threshold: float = 0.5) -> List[bool]: + """Validate texts using the fast model with a lower threshold of 0.5""" + # Vectorize texts + X = self.vectorizers[label].transform(texts) + + # Get probabilities + probs = self.models[label].predict_proba(X)[:, 1] + + # Return boolean mask with lower threshold + return probs >= threshold + +class ToxicAugmenter: + def __init__(self): + logger.info("Initializing ToxicAugmenter...") + + # Initialize generation buffer + self.generation_buffer = [] + self.buffer_size = 100 + + # Multi-GPU setup + self.num_gpus = torch.cuda.device_count() + if self.num_gpus > 0: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + logger.info(f"Found {self.num_gpus} GPUs:") + for i in range(self.num_gpus): + mem = torch.cuda.get_device_properties(i).total_memory / 1024**3 + logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)} 
({mem:.1f}GB)") + + # Load models with optimized settings + logger.info("Loading Mistral-7B...") + + # Configure model for multi-GPU with optimized settings + quantization_config = BitsAndBytesConfig( + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True + ) + + # Initialize tokenizer first + self.llm_tokenizer = AutoTokenizer.from_pretrained( + "mistralai/Mistral-7B-Instruct-v0.3", + padding_side="left", + use_fast=True, + model_max_length=512, # Limit context size for faster processing + pad_token="", # Set pad token explicitly + add_eos_token=True, + add_bos_token=True + ) + + # Set pad token id explicitly + self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token + self.llm_tokenizer.pad_token_id = self.llm_tokenizer.eos_token_id + + # Now initialize model with tokenizer's pad token + self.llm = AutoModelForCausalLM.from_pretrained( + "mistralai/Mistral-7B-Instruct-v0.3", + device_map="balanced", + torch_dtype=torch.float16, + quantization_config=quantization_config, + max_memory={0: "22GB", 1: "22GB"}, + use_cache=True, # Enable KV cache for faster generation + pad_token_id=self.llm_tokenizer.pad_token_id + ) + + logger.info("✓ Mistral-7B loaded") + + # Initialize validator + self.validator = FastToxicValidator() + logger.info("✓ Fast validator initialized") + + def generate_prompt(self, seed_text: str, label: str) -> str: + """Generate a prompt for a single toxicity type""" + base_prompt = """[INST] You are tasked with generating a highly {type} comment. Your goal is to create content that would definitely be flagged as {type}. + +Example for reference (generate something different): "{seed_text}" + +Requirements: +- Must be extremely {type} in nature +- Use very strong language +- Must be different from the example + +Generate ONLY the {type} comment (no explanations): [/INST]""" + + type_mapping = { + 'toxic': 'toxic and offensive', + 'severe_toxic': 'toxic and hostile', + 'obscene': 'obscene and vulgar', + 'threat': 'threatening and intimidating', + 'insult': 'insulting and demeaning', + 'identity_hate': 'discriminatory and hateful' + } + + return base_prompt.format(type=type_mapping[label], seed_text=seed_text) + + def flush_buffer(self): + """Flush the generation buffer to disk""" + if self.generation_buffer: + try: + with open(log_file, 'a', encoding='utf-8') as f: + for entry in self.generation_buffer: + f.write(json.dumps(entry, ensure_ascii=False) + '\n') + self.generation_buffer = [] + except Exception as e: + logger.error(f"Failed to flush buffer: {str(e)}") + + def log_generation(self, seed_text: str, prompt: str, generated_text: str, validation_results: Dict[str, bool]): + """Buffer log generation details with proper JSON serialization""" + # Convert numpy/torch boolean values to Python booleans + serializable_results = { + k: bool(v) for k, v in validation_results.items() + } + + log_entry = { + "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "seed_text": seed_text, + "prompt": prompt, + "generated_text": generated_text, + "validation_results": serializable_results + } + + self.generation_buffer.append(log_entry) + + # Flush buffer if it reaches the size limit + if len(self.generation_buffer) >= self.buffer_size: + self.flush_buffer() + + def validate_sample(self, text: str, label: str, attempts: int) -> bool: + """Validate a generated sample for a single label""" + try: + # Get probability for the label + prob = self.validator.get_probabilities([text], label)[0] + + # Super lenient thresholds to get 
samples + thresholds = { + 'toxic': 0.3, # Very lenient to get started + 'severe_toxic': 0.2, + 'obscene': 0.2, + 'threat': 0.2, + 'insult': 0.2, + 'identity_hate': 0.2 + } + threshold = thresholds.get(label, 0.2) + + # Log validation result + passed = prob >= threshold + if passed or (attempts % 5 == 0): # More frequent logging + logger.info(f"\nValidation - Label: {label}, Text: {text}") + logger.info(f"Probability: {prob:.3f}, Threshold: {threshold:.2f}, Passed: {passed}") + + return passed + + except Exception as e: + logger.error(f"Validation error: {str(e)}") + return False + + def generate_samples(self, target_samples: int, label: str, + seed_texts: List[str], total_timeout: int = 300) -> pd.DataFrame: + """Generate samples for a single label with timeouts""" + start_time = time.time() + generated_samples = [] + attempts = 0 + max_attempts = target_samples * 50 # Much more attempts allowed + batch_size = min(16, target_samples) # Smaller batch size for better control + + pbar = tqdm(total=target_samples, desc=f"Generating {label} samples") + + try: + while len(generated_samples) < target_samples and attempts < max_attempts: + # Check timeout + if time.time() - start_time > total_timeout: + logger.warning(f"Generation timed out after {total_timeout} seconds") + break + + attempts += 1 + + # Select random seed text and generate prompt + seed_text = random.choice(seed_texts) + prompt = self.generate_prompt(seed_text, label) + + try: + # Generate text with optimized parameters + inputs = self.llm_tokenizer(prompt, return_tensors="pt", padding=True, + truncation=True, max_length=512).to(self.llm.device) + + with torch.no_grad(): + outputs = self.llm.generate( + **inputs, + max_new_tokens=200, # Doubled for longer content + num_beams=4, # Added beam search + temperature=1.35, # Higher temperature for more randomness + do_sample=True, + top_p=0.99, # Almost no filtering + top_k=200, # More options + num_return_sequences=1, + repetition_penalty=1.0, # No repetition penalty + no_repeat_ngram_size=0, # No ngram blocking + early_stopping=True, # Stop when complete + pad_token_id=self.llm_tokenizer.pad_token_id, + bos_token_id=self.llm_tokenizer.bos_token_id, + eos_token_id=self.llm_tokenizer.eos_token_id, + use_cache=True + ) + + text = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True) + + # Extract the generated text after [/INST] + if "[/INST]" in text: + output = text.split("[/INST]")[1].strip() + output = output.strip().strip('"').strip("'") + + # Only check minimum length + if len(output) >= 10: + # Log generation attempt + if attempts % 5 == 0: # More frequent logging + logger.info(f"\nAttempt {attempts}: Generated text: {output}") + + # Validate sample + if self.validate_sample(output, label, attempts): + sample_dict = {'comment_text': output} + for l in ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']: + sample_dict[l] = 1 if l == label else 0 + generated_samples.append(sample_dict) + pbar.update(1) + logger.info(f"✓ Valid {label} sample generated ({len(generated_samples)}/{target_samples})") + + except Exception as e: + logger.error(f"Generation error on attempt {attempts}: {str(e)}") + continue + + # Clear cache less frequently + if attempts % 200 == 0: + torch.cuda.empty_cache() + gc.collect() + + finally: + pbar.close() + logger.info(f"Generation finished: {len(generated_samples)}/{target_samples} samples in {attempts} attempts") + + # Return results even if partial + if generated_samples: + return pd.DataFrame(generated_samples) + return None 
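+
+    # Usage sketch (illustrative only, mirroring how balance_english.py drives this
+    # class; the concrete label, count, and seed-frame name below are hypothetical):
+    #
+    #     augmenter = ToxicAugmenter()
+    #     seed_texts = en_df[en_df['insult'] == 1]['comment_text'].tolist()
+    #     new_samples = augmenter.augment_dataset(
+    #         target_samples=500,
+    #         label='insult',
+    #         seed_texts=seed_texts,
+    #         timeout_minutes=5,
+    #     )
+    #
+    # augment_dataset() below batches calls to generate_samples() above and returns a
+    # DataFrame with 'comment_text' plus one binary column per toxicity label.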
+ + def augment_dataset(self, target_samples: int, label: str, seed_texts: List[str], timeout_minutes: int = 5) -> pd.DataFrame: + """Generate a specific number of samples with given label combination""" + logger.info(f"\nGenerating {target_samples} samples with label: {label}") + + generated_samples = [] + batch_size = min(32, target_samples) + start_time = time.time() + timeout_seconds = min(timeout_minutes * 60, 300) # Hard limit of 5 minutes + total_generated = 0 + pbar = None + + try: + # Create progress bar + pbar = tqdm( + total=target_samples, + desc="Generating", + unit="samples", + ncols=100, + bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]' + ) + + while total_generated < target_samples: + # Check timeout + elapsed_time = time.time() - start_time + if elapsed_time > timeout_seconds: + logger.warning(f"Time limit reached after {elapsed_time/60:.1f} minutes") + break + + # Calculate remaining samples needed + remaining = target_samples - total_generated + current_batch_size = min(batch_size, remaining) + + # Select batch of seed texts + batch_seeds = np.random.choice(seed_texts, size=current_batch_size) + prompts = [self.generate_prompt(seed, label) for seed in batch_seeds] + + # Generate and validate samples + batch_start = time.time() + new_samples = self.generate_samples( + target_samples=current_batch_size, + label=label, + seed_texts=batch_seeds, + total_timeout=timeout_seconds - elapsed_time + ) + + if new_samples is not None and not new_samples.empty: + if len(new_samples) > remaining: + new_samples = new_samples.head(remaining) + + generated_samples.append(new_samples) + num_new = len(new_samples) + total_generated += num_new + + # Update progress bar + pbar.update(num_new) + + # Calculate and display metrics + elapsed_minutes = elapsed_time / 60 + rate = total_generated / elapsed_minutes if elapsed_minutes > 0 else 0 + batch_time = time.time() - batch_start + time_remaining = max(0, timeout_seconds - elapsed_time) + + pbar.set_postfix({ + 'rate': f'{rate:.1f}/min', + 'batch': f'{batch_time:.1f}s', + 'remain': f'{time_remaining:.0f}s' + }, refresh=True) + + # Memory management every few batches + if total_generated % (batch_size * 4) == 0: + torch.cuda.empty_cache() + + # Combine all generated samples + if generated_samples: + final_df = pd.concat(generated_samples, ignore_index=True) + if len(final_df) > target_samples: + final_df = final_df.head(target_samples) + logger.info(f"Successfully generated {len(final_df)} samples in {elapsed_time/60:.1f} minutes") + return final_df + + return None + + except Exception as e: + logger.error(f"Generation error: {str(e)}") + return None + finally: + if pbar is not None: + pbar.close() + # Final cleanup + self.flush_buffer() + torch.cuda.empty_cache() \ No newline at end of file diff --git a/datacard.md b/datacard.md new file mode 100644 index 0000000000000000000000000000000000000000..ac1602c87b9fa71b1cad0585924ca70d607aab40 --- /dev/null +++ b/datacard.md @@ -0,0 +1,39 @@ +# Jigsaw Toxic Comment Classification Dataset + +## Overview +Version: 1.0 +Date Created: 2025-02-03 + +### Description + + The Jigsaw Toxic Comment Classification Dataset is designed to help identify and classify toxic online comments. + It contains text comments with multiple toxicity-related labels including general toxicity, severe toxicity, + obscenity, threats, insults, and identity-based hate speech. + + The dataset includes: + 1. Main training data with binary toxicity labels + 2. 
Unintended bias training data with additional identity attributes + 3. Processed versions with sequence length 128 for direct model input + 4. Test and validation sets for model evaluation + + This dataset was created by Jigsaw and Google's Conversation AI team to help improve online conversation quality + by identifying and classifying various forms of toxic comments. + + +## Column Descriptions + +- **id**: Unique identifier for each comment +- **comment_text**: The text content of the comment to be classified +- **toxic**: Binary label indicating if the comment is toxic +- **severe_toxic**: Binary label for extremely toxic comments +- **obscene**: Binary label for obscene content +- **threat**: Binary label for threatening content +- **insult**: Binary label for insulting content +- **identity_hate**: Binary label for identity-based hate speech +- **target**: Overall toxicity score (in bias dataset) +- **identity_attack**: Binary label for identity-based attacks +- **identity_***: Various identity-related attributes in the bias dataset +- **lang**: Language of the comment + +## Files + diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..9f1d722b68c461f3481238a851b32e20fd53800a --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,13 @@ +version: '3.8' + +services: + toxic-classifier: + build: . + runtime: nvidia # Enable NVIDIA runtime for GPU support + environment: + - NVIDIA_VISIBLE_DEVICES=all + - WANDB_API_KEY=${WANDB_API_KEY} # Set this in .env file + volumes: + - ./dataset:/app/dataset # Mount dataset directory + - ./weights:/app/weights # Mount weights directory + command: python model/train.py # Default command, can be overridden \ No newline at end of file diff --git a/evaluation_results/eval_20250208_161149/confusion_matrices/cm_identity_hate.png b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_identity_hate.png new file mode 100644 index 0000000000000000000000000000000000000000..b7de301074486bff7e5d44d9a23e2afdc4c5a925 Binary files /dev/null and b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_identity_hate.png differ diff --git a/evaluation_results/eval_20250208_161149/confusion_matrices/cm_insult.png b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_insult.png new file mode 100644 index 0000000000000000000000000000000000000000..228c23926d820cb4d7ea8a2ae56c52ccc2d18636 Binary files /dev/null and b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_insult.png differ diff --git a/evaluation_results/eval_20250208_161149/confusion_matrices/cm_obscene.png b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_obscene.png new file mode 100644 index 0000000000000000000000000000000000000000..ac1a636a891dde3b4465a14a5f47948644399f0b Binary files /dev/null and b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_obscene.png differ diff --git a/evaluation_results/eval_20250208_161149/confusion_matrices/cm_severe_toxic.png b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_severe_toxic.png new file mode 100644 index 0000000000000000000000000000000000000000..1372f56aaf537b81fd3972f9abd3424b0f2ea755 Binary files /dev/null and b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_severe_toxic.png differ diff --git a/evaluation_results/eval_20250208_161149/confusion_matrices/cm_threat.png b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_threat.png new file mode 100644 index 
0000000000000000000000000000000000000000..5cb73b6e800c4cd2f2ec9f76c6d8fe05b2f04fbc Binary files /dev/null and b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_threat.png differ diff --git a/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic.png b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic.png new file mode 100644 index 0000000000000000000000000000000000000000..5986ca310cb3437f8696063dd002b512ee17b415 Binary files /dev/null and b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic.png differ diff --git a/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_0.png b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_0.png new file mode 100644 index 0000000000000000000000000000000000000000..79bd442ac84d81cdc8564337e18a973c85872beb Binary files /dev/null and b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_0.png differ diff --git a/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_1.png b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_1.png new file mode 100644 index 0000000000000000000000000000000000000000..60f85a418f83d79bc844c6f816f8c0b88c9b90dc Binary files /dev/null and b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_1.png differ diff --git a/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_2.png b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_2.png new file mode 100644 index 0000000000000000000000000000000000000000..5cf9f1732a6f4b5ea68f143ac0e12563476298b6 Binary files /dev/null and b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_2.png differ diff --git a/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_3.png b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_3.png new file mode 100644 index 0000000000000000000000000000000000000000..12a4b3cacef5b14ed58722f916d4b2c46c1091dd Binary files /dev/null and b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_3.png differ diff --git a/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_4.png b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_4.png new file mode 100644 index 0000000000000000000000000000000000000000..08a2623f3076d135eb512eb4b176d8638724d494 Binary files /dev/null and b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_4.png differ diff --git a/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_5.png b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_5.png new file mode 100644 index 0000000000000000000000000000000000000000..5f86526ba02d8e0f5ed3a8efe3ab06b321996296 Binary files /dev/null and b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_5.png differ diff --git a/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_6.png b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_6.png new file mode 100644 index 0000000000000000000000000000000000000000..d134c02d9d4a65212d3e8565d19c4053a70daf82 Binary files /dev/null and b/evaluation_results/eval_20250208_161149/confusion_matrices/cm_toxic_6.png differ diff --git a/evaluation_results/eval_20250208_161149/eval_params.json b/evaluation_results/eval_20250208_161149/eval_params.json new file mode 100644 index 0000000000000000000000000000000000000000..013692c8ef914f9b7e1aaef466cff4faff949332 --- /dev/null +++ 
b/evaluation_results/eval_20250208_161149/eval_params.json @@ -0,0 +1,7 @@ +{ + "timestamp": "20250208_161149", + "model_path": "weights/toxic_classifier_xlm-roberta-large", + "test_file": "dataset/split/test.csv", + "batch_size": 32, + "num_workers": null +} \ No newline at end of file diff --git a/evaluation_results/eval_20250208_161149/evaluation_results.json b/evaluation_results/eval_20250208_161149/evaluation_results.json new file mode 100644 index 0000000000000000000000000000000000000000..8b820c62e944cf377020eedaebc76b9fe0e4966d --- /dev/null +++ b/evaluation_results/eval_20250208_161149/evaluation_results.json @@ -0,0 +1,2020 @@ +{ + "overall": { + "loss": 0.18776385083473274, + "auc_macro": 0.9259171799699759, + "auc_weighted": 0.9442696333538418, + "precision_macro": 0.4388604553772207, + "precision_weighted": 0.7008073672218381, + "recall_macro": 0.8836014181101747, + "recall_weighted": 0.9051010634378761, + "f1_macro": 0.530782857064369, + "f1_weighted": 0.7669279374035199, + "class_support": { + "toxic": 17646, + "severe_toxic": 1649, + "obscene": 8625, + "threat": 714, + "insult": 10201, + "identity_hate": 1882 + }, + "per_class_metrics": { + "toxic": { + "precision": 0.9115322083309974, + "recall": 0.9213986172503683, + "f1": 0.9164388580446975, + "support": 17646, + "specificity": 0.9121478677207437 + }, + "severe_toxic": { + "precision": 0.15755900489049543, + "recall": 0.8987265009096422, + "f1": 0.26811397557666217, + "support": 1649, + "specificity": 0.7666597956359139 + }, + "obscene": { + "precision": 0.6238325281803543, + "recall": 0.8983188405797101, + "f1": 0.7363269185079592, + "support": 8625, + "specificity": 0.8268539450765297 + }, + "threat": { + "precision": 0.10505486598309048, + "recall": 0.8179271708683473, + "f1": 0.18619480312450185, + "support": 714, + "specificity": 0.8574253453315757 + }, + "insult": { + "precision": 0.6205890336590663, + "recall": 0.8964807371826291, + "f1": 0.7334482896900189, + "support": 10201, + "specificity": 0.7799425355217067 + }, + "identity_hate": { + "precision": 0.21459509121932013, + "recall": 0.8687566418703507, + "f1": 0.3441742974423745, + "support": 1882, + "specificity": 0.822570123939987 + } + }, + "class_weights": { + "toxic": 0.43338163420684234, + "severe_toxic": 0.04049905444900165, + "obscene": 0.21182798339759806, + "threat": 0.017535673060392463, + "insult": 0.2505341749146548, + "identity_hate": 0.04622147997151067 + }, + "hamming_loss": 0.1618924586235303, + "exact_match": 0.499747247809481, + "specificity_macro": 0.8275999355377427, + "specificity_weighted": 0.8275999355377428, + "summary": { + "auc": { + "macro": 0.9259171799699759, + "weighted": 0.9442696333538418 + }, + "f1": { + "macro": 0.530782857064369, + "weighted": 0.7669279374035199 + }, + "precision": { + "macro": 0.4388604553772207, + "weighted": 0.7008073672218381 + }, + "recall": { + "macro": 0.8836014181101747, + "weighted": 0.9051010634378761 + }, + "specificity": { + "macro": 0.8275999355377427, + "weighted": 0.8275999355377428 + }, + "other_metrics": { + "hamming_loss": 0.1618924586235303, + "exact_match": 0.499747247809481 + }, + "class_support": { + "toxic": 17646, + "severe_toxic": 1649, + "obscene": 8625, + "threat": 714, + "insult": 10201, + "identity_hate": 1882 + } + } + }, + "per_language": { + "0": { + "auc": 0.9546775894690953, + "precision": 0.714413481020392, + "recall": 0.9246670642019479, + "f1": 0.7877150106257862, + "hamming_loss": 0.12826939843068874, + "exact_match": 0.5564516129032258, + "specificity": 
0.8596476657420098, + "class_metrics": { + "toxic": { + "auc": 0.9621138334064959, + "threshold": 0.46047261357307434, + "precision": 0.8825137733163603, + "recall": 0.9342830882352909, + "f1": 0.9076608519017388, + "specificity": 0.8756218905472631, + "npv": 0.9301878222768437, + "positive_samples": 2176, + "true_positives": 2143, + "false_positives": 285, + "true_negatives": 2008, + "false_negatives": 150, + "auc_ci": [ + 0.9621138334064959, + 0.9621138334064959 + ], + "precision_ci": [ + 0.8825137733163603, + 0.8825137733163603 + ], + "recall_ci": [ + 0.9342830882352909, + 0.9342830882352909 + ], + "f1_ci": [ + 0.9076608519017388, + 0.9076608519017388 + ], + "specificity_ci": [ + 0.8756218905472631, + 0.8756218905472631 + ], + "npv_ci": [ + 0.9301878222768437, + 0.9301878222768437 + ], + "class_weights": { + "0.0": 0.951077943615257, + "1.0": 1.0542279411764706 + } + }, + "severe_toxic": { + "auc": 0.9499761279127715, + "threshold": 0.03537772223353386, + "precision": 0.8608043862269837, + "recall": 0.9492385786802037, + "f1": 0.9028611452277716, + "specificity": 0.8465042131632855, + "npv": 0.9434265401805545, + "positive_samples": 197, + "true_positives": 2177, + "false_positives": 352, + "true_negatives": 1941, + "false_negatives": 116, + "auc_ci": [ + 0.9499761279127715, + 0.9499761279127715 + ], + "precision_ci": [ + 0.8608043862269837, + 0.8608043862269837 + ], + "recall_ci": [ + 0.9492385786802037, + 0.9492385786802037 + ], + "f1_ci": [ + 0.9028611452277716, + 0.9028611452277716 + ], + "specificity_ci": [ + 0.8465042131632855, + 0.8465042131632855 + ], + "npv_ci": [ + 0.9434265401805545, + 0.9434265401805545 + ], + "class_weights": { + "0.0": 0.5224322477795491, + "1.0": 11.644670050761421 + } + }, + "obscene": { + "auc": 0.9572805958351019, + "threshold": 0.2777131497859955, + "precision": 0.8724828332798461, + "recall": 0.9115977291159771, + "f1": 0.8916114958872817, + "specificity": 0.8667660208643849, + "npv": 0.9074484866722257, + "positive_samples": 1233, + "true_positives": 2091, + "false_positives": 305, + "true_negatives": 1988, + "false_negatives": 202, + "auc_ci": [ + 0.9572805958351019, + 0.9572805958351019 + ], + "precision_ci": [ + 0.8724828332798461, + 0.8724828332798461 + ], + "recall_ci": [ + 0.9115977291159771, + 0.9115977291159771 + ], + "f1_ci": [ + 0.8916114958872817, + 0.8916114958872817 + ], + "specificity_ci": [ + 0.8667660208643849, + 0.8667660208643849 + ], + "npv_ci": [ + 0.9074484866722257, + 0.9074484866722257 + ], + "class_weights": { + "0.0": 0.6837555886736214, + "1.0": 1.8605028386050284 + } + }, + "threat": { + "auc": 0.9697358146798531, + "threshold": 0.016539234668016434, + "precision": 0.9045252081854022, + "recall": 0.9117647058823535, + "f1": 0.9081305291811165, + "specificity": 0.9037610619468958, + "npv": 0.9110528041980915, + "positive_samples": 68, + "true_positives": 2091, + "false_positives": 220, + "true_negatives": 2073, + "false_negatives": 202, + "auc_ci": [ + 0.9697358146798531, + 0.9697358146798531 + ], + "precision_ci": [ + 0.9045252081854022, + 0.9045252081854022 + ], + "recall_ci": [ + 0.9117647058823535, + 0.9117647058823535 + ], + "f1_ci": [ + 0.9081305291811165, + 0.9081305291811165 + ], + "specificity_ci": [ + 0.9037610619468958, + 0.9037610619468958 + ], + "npv_ci": [ + 0.9110528041980915, + 0.9110528041980915 + ], + "class_weights": { + "0.0": 0.5075221238938054, + "1.0": 33.73529411764706 + } + }, + "insult": { + "auc": 0.935014291573492, + "threshold": 0.25907590985298157, + "precision": 0.833978890287596, + "recall": 
0.9098862642169729, + "f1": 0.8702805202104968, + "specificity": 0.8188679245282912, + "npv": 0.900862976980011, + "positive_samples": 1143, + "true_positives": 2087, + "false_positives": 415, + "true_negatives": 1878, + "false_negatives": 206, + "auc_ci": [ + 0.935014291573492, + 0.935014291573492 + ], + "precision_ci": [ + 0.833978890287596, + 0.833978890287596 + ], + "recall_ci": [ + 0.9098862642169729, + 0.9098862642169729 + ], + "f1_ci": [ + 0.8702805202104968, + 0.8702805202104968 + ], + "specificity_ci": [ + 0.8188679245282912, + 0.8188679245282912 + ], + "npv_ci": [ + 0.900862976980011, + 0.900862976980011 + ], + "class_weights": { + "0.0": 0.6658925979680697, + "1.0": 2.0069991251093615 + } + }, + "identity_hate": { + "auc": 0.9686336850292078, + "threshold": 0.026042653247714043, + "precision": 0.8623651962191886, + "recall": 0.9626168224299065, + "f1": 0.909737451082551, + "specificity": 0.8463648834019236, + "npv": 0.9576992819322562, + "positive_samples": 214, + "true_positives": 2208, + "false_positives": 352, + "true_negatives": 1941, + "false_negatives": 85, + "auc_ci": [ + 0.9686336850292078, + 0.9686336850292078 + ], + "precision_ci": [ + 0.8623651962191886, + 0.8623651962191886 + ], + "recall_ci": [ + 0.9626168224299065, + 0.9626168224299065 + ], + "f1_ci": [ + 0.909737451082551, + 0.909737451082551 + ], + "specificity_ci": [ + 0.8463648834019236, + 0.8463648834019236 + ], + "npv_ci": [ + 0.9576992819322562, + 0.9576992819322562 + ], + "class_weights": { + "0.0": 0.5244627343392776, + "1.0": 10.719626168224298 + } + } + }, + "sample_count": 4588 + }, + "1": { + "auc": 0.9420109561343032, + "precision": 0.7054445371054338, + "recall": 0.8937771830043493, + "f1": 0.7655260008199765, + "hamming_loss": 0.16467680852429553, + "exact_match": 0.49354900828037745, + "specificity": 0.8275039240639036, + "class_metrics": { + "toxic": { + "auc": 0.970066021237747, + "threshold": 0.44148319959640503, + "precision": 0.9051201281749973, + "recall": 0.916216216216217, + "f1": 0.910634371966946, + "specificity": 0.903956972723781, + "npv": 0.9151763423430814, + "positive_samples": 2590, + "true_positives": 2378, + "false_positives": 249, + "true_negatives": 2347, + "false_negatives": 217, + "auc_ci": [ + 0.970066021237747, + 0.970066021237747 + ], + "precision_ci": [ + 0.9051201281749973, + 0.9051201281749973 + ], + "recall_ci": [ + 0.916216216216217, + 0.916216216216217 + ], + "f1_ci": [ + 0.910634371966946, + 0.910634371966946 + ], + "specificity_ci": [ + 0.903956972723781, + 0.903956972723781 + ], + "npv_ci": [ + 0.9151763423430814, + 0.9151763423430814 + ], + "class_weights": { + "0.0": 0.9975028812908183, + "1.0": 1.0025096525096524 + } + }, + "severe_toxic": { + "auc": 0.9032119421376688, + "threshold": 0.03648429363965988, + "precision": 0.8147008122253235, + "recall": 0.8688524590163955, + "f1": 0.8409057392553343, + "specificity": 0.8023843200646473, + "npv": 0.8595146599106457, + "positive_samples": 244, + "true_positives": 2255, + "false_positives": 513, + "true_negatives": 2083, + "false_negatives": 340, + "auc_ci": [ + 0.9032119421376688, + 0.9032119421376688 + ], + "precision_ci": [ + 0.8147008122253235, + 0.8147008122253235 + ], + "recall_ci": [ + 0.8688524590163955, + 0.8688524590163955 + ], + "f1_ci": [ + 0.8409057392553343, + 0.8409057392553343 + ], + "specificity_ci": [ + 0.8023843200646473, + 0.8023843200646473 + ], + "npv_ci": [ + 0.8595146599106457, + 0.8595146599106457 + ], + "class_weights": { + "0.0": 0.5246514447363103, + "1.0": 10.64139344262295 + } + }, + 
"obscene": { + "auc": 0.9387485218400086, + "threshold": 0.1990610957145691, + "precision": 0.8573644543610149, + "recall": 0.8723747980614001, + "f1": 0.8648044977770555, + "specificity": 0.8548672566371623, + "npv": 0.8701005785595336, + "positive_samples": 1238, + "true_positives": 2265, + "false_positives": 376, + "true_negatives": 2219, + "false_negatives": 331, + "auc_ci": [ + 0.9387485218400086, + 0.9387485218400086 + ], + "precision_ci": [ + 0.8573644543610149, + 0.8573644543610149 + ], + "recall_ci": [ + 0.8723747980614001, + 0.8723747980614001 + ], + "f1_ci": [ + 0.8648044977770555, + 0.8648044977770555 + ], + "specificity_ci": [ + 0.8548672566371623, + 0.8548672566371623 + ], + "npv_ci": [ + 0.8701005785595336, + 0.8701005785595336 + ], + "class_weights": { + "0.0": 0.6565107458912769, + "1.0": 2.097334410339257 + } + }, + "threat": { + "auc": 0.930141945247047, + "threshold": 0.012619060464203358, + "precision": 0.8505847769217403, + "recall": 0.8773584905660369, + "f1": 0.8637642103418028, + "specificity": 0.8458816591311225, + "npv": 0.8733726632315268, + "positive_samples": 106, + "true_positives": 2278, + "false_positives": 400, + "true_negatives": 2196, + "false_negatives": 318, + "auc_ci": [ + 0.930141945247047, + 0.930141945247047 + ], + "precision_ci": [ + 0.8505847769217403, + 0.8505847769217403 + ], + "recall_ci": [ + 0.8773584905660369, + 0.8773584905660369 + ], + "f1_ci": [ + 0.8637642103418028, + 0.8637642103418028 + ], + "specificity_ci": [ + 0.8458816591311225, + 0.8458816591311225 + ], + "npv_ci": [ + 0.8733726632315268, + 0.8733726632315268 + ], + "class_weights": { + "0.0": 0.5104187143699627, + "1.0": 24.495283018867923 + } + }, + "insult": { + "auc": 0.9116567628368878, + "threshold": 0.24214455485343933, + "precision": 0.8063856025869378, + "recall": 0.8794466403162026, + "f1": 0.8413329522908936, + "specificity": 0.7888435374149729, + "npv": 0.8674359236672227, + "positive_samples": 1518, + "true_positives": 2283, + "false_positives": 548, + "true_negatives": 2048, + "false_negatives": 313, + "auc_ci": [ + 0.9116567628368878, + 0.9116567628368878 + ], + "precision_ci": [ + 0.8063856025869378, + 0.8063856025869378 + ], + "recall_ci": [ + 0.8794466403162026, + 0.8794466403162026 + ], + "f1_ci": [ + 0.8413329522908936, + 0.8413329522908936 + ], + "specificity_ci": [ + 0.7888435374149729, + 0.7888435374149729 + ], + "npv_ci": [ + 0.8674359236672227, + 0.8674359236672227 + ], + "class_weights": { + "0.0": 0.706530612244898, + "1.0": 1.7104743083003953 + } + }, + "identity_hate": { + "auc": 0.9000925697269513, + "threshold": 0.03167847916483879, + "precision": 0.7933569321076599, + "recall": 0.8865248226950354, + "f1": 0.8373572860825882, + "specificity": 0.7690897984117396, + "npv": 0.8714256962068888, + "positive_samples": 282, + "true_positives": 2301, + "false_positives": 599, + "true_negatives": 1996, + "false_negatives": 294, + "auc_ci": [ + 0.9000925697269513, + 0.9000925697269513 + ], + "precision_ci": [ + 0.7933569321076599, + 0.7933569321076599 + ], + "recall_ci": [ + 0.8865248226950354, + 0.8865248226950354 + ], + "f1_ci": [ + 0.8373572860825882, + 0.8373572860825882 + ], + "specificity_ci": [ + 0.7690897984117396, + 0.7690897984117396 + ], + "npv_ci": [ + 0.8714256962068888, + 0.8714256962068888 + ], + "class_weights": { + "0.0": 0.5287110568112401, + "1.0": 9.207446808510639 + } + } + }, + "sample_count": 5193 + }, + "2": { + "auc": 0.9291857688264461, + "precision": 0.6563281876729908, + "recall": 0.9071871335232032, + "f1": 0.7348671832220326, + 
"hamming_loss": 0.20595261153076377, + "exact_match": 0.4263025372845245, + "specificity": 0.7733622212755961, + "class_metrics": { + "toxic": { + "auc": 0.962186696069825, + "threshold": 0.3978160321712494, + "precision": 0.8937958373522624, + "recall": 0.9136996904024615, + "f1": 0.9036381748465286, + "specificity": 0.8914307871267977, + "npv": 0.9117341057406776, + "positive_samples": 2584, + "true_positives": 2358, + "false_positives": 280, + "true_negatives": 2301, + "false_negatives": 222, + "auc_ci": [ + 0.962186696069825, + 0.962186696069825 + ], + "precision_ci": [ + 0.8937958373522624, + 0.8937958373522624 + ], + "recall_ci": [ + 0.9136996904024615, + 0.9136996904024615 + ], + "f1_ci": [ + 0.9036381748465286, + 0.9036381748465286 + ], + "specificity_ci": [ + 0.8914307871267977, + 0.8914307871267977 + ], + "npv_ci": [ + 0.9117341057406776, + 0.9117341057406776 + ], + "class_weights": { + "0.0": 1.0009693679720821, + "1.0": 0.9990325077399381 + } + }, + "severe_toxic": { + "auc": 0.890519864426667, + "threshold": 0.015000982210040092, + "precision": 0.7460680730510791, + "recall": 0.918032786885247, + "f1": 0.8231651924456013, + "specificity": 0.6875381175035498, + "npv": 0.8934806428840502, + "positive_samples": 244, + "true_positives": 2369, + "false_positives": 806, + "true_negatives": 1774, + "false_negatives": 211, + "auc_ci": [ + 0.890519864426667, + 0.890519864426667 + ], + "precision_ci": [ + 0.7460680730510791, + 0.7460680730510791 + ], + "recall_ci": [ + 0.918032786885247, + 0.918032786885247 + ], + "f1_ci": [ + 0.8231651924456013, + 0.8231651924456013 + ], + "specificity_ci": [ + 0.6875381175035498, + 0.6875381175035498 + ], + "npv_ci": [ + 0.8934806428840502, + 0.8934806428840502 + ], + "class_weights": { + "0.0": 0.5248017889815003, + "1.0": 10.579918032786885 + } + }, + "obscene": { + "auc": 0.9233059279915251, + "threshold": 0.11362762749195099, + "precision": 0.7873800414823968, + "recall": 0.9095315024232634, + "f1": 0.8440592612850891, + "specificity": 0.7543949044586057, + "npv": 0.892919379205219, + "positive_samples": 1238, + "true_positives": 2347, + "false_positives": 634, + "true_negatives": 1947, + "false_negatives": 233, + "auc_ci": [ + 0.9233059279915251, + 0.9233059279915251 + ], + "precision_ci": [ + 0.7873800414823968, + 0.7873800414823968 + ], + "recall_ci": [ + 0.9095315024232634, + 0.9095315024232634 + ], + "f1_ci": [ + 0.8440592612850891, + 0.8440592612850891 + ], + "specificity_ci": [ + 0.7543949044586057, + 0.7543949044586057 + ], + "npv_ci": [ + 0.892919379205219, + 0.892919379205219 + ], + "class_weights": { + "0.0": 0.6577070063694268, + "1.0": 2.0852180936995155 + } + }, + "threat": { + "auc": 0.848578598380765, + "threshold": 0.008195769973099232, + "precision": 0.7785886139481758, + "recall": 0.8055555555555555, + "f1": 0.791842555156752, + "specificity": 0.7709198813056214, + "npv": 0.7985792107105536, + "positive_samples": 108, + "true_positives": 2079, + "false_positives": 591, + "true_negatives": 1990, + "false_negatives": 501, + "auc_ci": [ + 0.848578598380765, + 0.848578598380765 + ], + "precision_ci": [ + 0.7785886139481758, + 0.7785886139481758 + ], + "recall_ci": [ + 0.8055555555555555, + 0.8055555555555555 + ], + "f1_ci": [ + 0.791842555156752, + 0.791842555156752 + ], + "specificity_ci": [ + 0.7709198813056214, + 0.7709198813056214 + ], + "npv_ci": [ + 0.7985792107105536, + 0.7985792107105536 + ], + "class_weights": { + "0.0": 0.5106824925816024, + "1.0": 23.90277777777778 + } + }, + "insult": { + "auc": 0.8943137096607889, + 
"threshold": 0.1587354838848114, + "precision": 0.7484673378377763, + "recall": 0.9141347424042362, + "f1": 0.8230472043830551, + "specificity": 0.6927925459029957, + "npv": 0.889726581805318, + "positive_samples": 1514, + "true_positives": 2359, + "false_positives": 793, + "true_negatives": 1788, + "false_negatives": 221, + "auc_ci": [ + 0.8943137096607889, + 0.8943137096607889 + ], + "precision_ci": [ + 0.7484673378377763, + 0.7484673378377763 + ], + "recall_ci": [ + 0.9141347424042362, + 0.9141347424042362 + ], + "f1_ci": [ + 0.8230472043830551, + 0.8230472043830551 + ], + "specificity_ci": [ + 0.6927925459029957, + 0.6927925459029957 + ], + "npv_ci": [ + 0.889726581805318, + 0.889726581805318 + ], + "class_weights": { + "0.0": 0.7074540970128802, + "1.0": 1.7050858652575958 + } + }, + "identity_hate": { + "auc": 0.9040654827596841, + "threshold": 0.0467526838183403, + "precision": 0.8408828817107497, + "recall": 0.8291814946619218, + "f1": 0.8349911950184066, + "specificity": 0.8430970913560043, + "npv": 0.8315259121222329, + "positive_samples": 281, + "true_positives": 2140, + "false_positives": 405, + "true_negatives": 2176, + "false_negatives": 440, + "auc_ci": [ + 0.9040654827596841, + 0.9040654827596841 + ], + "precision_ci": [ + 0.8408828817107497, + 0.8408828817107497 + ], + "recall_ci": [ + 0.8291814946619218, + 0.8291814946619218 + ], + "f1_ci": [ + 0.8349911950184066, + 0.8349911950184066 + ], + "specificity_ci": [ + 0.8430970913560043, + 0.8430970913560043 + ], + "npv_ci": [ + 0.8315259121222329, + 0.8315259121222329 + ], + "class_weights": { + "0.0": 0.5287791888570258, + "1.0": 9.186832740213523 + } + } + }, + "sample_count": 5163 + }, + "3": { + "auc": 0.9472472410532857, + "precision": 0.6982701786686969, + "recall": 0.9152656355077337, + "f1": 0.7674148586410611, + "hamming_loss": 0.1731811145510836, + "exact_match": 0.48471362229102166, + "specificity": 0.8133241121366614, + "class_metrics": { + "toxic": { + "auc": 0.9747483574660619, + "threshold": 0.5033379793167114, + "precision": 0.9204374197691823, + "recall": 0.9294300116324036, + "f1": 0.9249118582673775, + "specificity": 0.9196601004248757, + "npv": 0.9287337466652424, + "positive_samples": 2579, + "true_positives": 2401, + "false_positives": 207, + "true_negatives": 2376, + "false_negatives": 182, + "auc_ci": [ + 0.9747483574660619, + 0.9747483574660619 + ], + "precision_ci": [ + 0.9204374197691823, + 0.9204374197691823 + ], + "recall_ci": [ + 0.9294300116324036, + 0.9294300116324036 + ], + "f1_ci": [ + 0.9249118582673775, + 0.9249118582673775 + ], + "specificity_ci": [ + 0.9196601004248757, + 0.9196601004248757 + ], + "npv_ci": [ + 0.9287337466652424, + 0.9287337466652424 + ], + "class_weights": { + "0.0": 0.9980687524140595, + "1.0": 1.0019387359441645 + } + }, + "severe_toxic": { + "auc": 0.9073687265747961, + "threshold": 0.021415209397673607, + "precision": 0.7618540559183846, + "recall": 0.93388429752066, + "f1": 0.8391430651806406, + "specificity": 0.7080795777506993, + "npv": 0.9146007419992344, + "positive_samples": 242, + "true_positives": 2413, + "false_positives": 754, + "true_negatives": 1829, + "false_negatives": 170, + "auc_ci": [ + 0.9073687265747961, + 0.9073687265747961 + ], + "precision_ci": [ + 0.7618540559183846, + 0.7618540559183846 + ], + "recall_ci": [ + 0.93388429752066, + 0.93388429752066 + ], + "f1_ci": [ + 0.8391430651806406, + 0.8391430651806406 + ], + "specificity_ci": [ + 0.7080795777506993, + 0.7080795777506993 + ], + "npv_ci": [ + 0.9146007419992344, + 0.9146007419992344 + ], + 
"class_weights": { + "0.0": 0.5245635403978888, + "1.0": 10.677685950413224 + } + }, + "obscene": { + "auc": 0.9429228614622618, + "threshold": 0.14896434545516968, + "precision": 0.822101549733319, + "recall": 0.9148418491484125, + "f1": 0.8659958665665364, + "specificity": 0.8020330368488026, + "npv": 0.9040137548341648, + "positive_samples": 1233, + "true_positives": 2363, + "false_positives": 511, + "true_negatives": 2072, + "false_negatives": 220, + "auc_ci": [ + 0.9429228614622618, + 0.9429228614622618 + ], + "precision_ci": [ + 0.822101549733319, + 0.822101549733319 + ], + "recall_ci": [ + 0.9148418491484125, + 0.9148418491484125 + ], + "f1_ci": [ + 0.8659958665665364, + 0.8659958665665364 + ], + "specificity_ci": [ + 0.8020330368488026, + 0.8020330368488026 + ], + "npv_ci": [ + 0.9040137548341648, + 0.9040137548341648 + ], + "class_weights": { + "0.0": 0.6566709021601016, + "1.0": 2.095701540957015 + } + }, + "threat": { + "auc": 0.8985232762406729, + "threshold": 0.013273251242935658, + "precision": 0.8299773755655987, + "recall": 0.8055555555555544, + "f1": 0.8175841319366995, + "specificity": 0.8349802371541444, + "npv": 0.8111134812286639, + "positive_samples": 108, + "true_positives": 2081, + "false_positives": 426, + "true_negatives": 2157, + "false_negatives": 502, + "auc_ci": [ + 0.8985232762406729, + 0.8985232762406729 + ], + "precision_ci": [ + 0.8299773755655987, + 0.8299773755655987 + ], + "recall_ci": [ + 0.8055555555555544, + 0.8055555555555544 + ], + "f1_ci": [ + 0.8175841319366995, + 0.8175841319366995 + ], + "specificity_ci": [ + 0.8349802371541444, + 0.8349802371541444 + ], + "npv_ci": [ + 0.8111134812286639, + 0.8111134812286639 + ], + "class_weights": { + "0.0": 0.5106719367588933, + "1.0": 23.925925925925927 + } + }, + "insult": { + "auc": 0.9178884966596437, + "threshold": 0.22368550300598145, + "precision": 0.8017937840347082, + "recall": 0.9065606361828928, + "f1": 0.8509647346472855, + "specificity": 0.7758950532932412, + "npv": 0.8925162032262658, + "positive_samples": 1509, + "true_positives": 2342, + "false_positives": 579, + "true_negatives": 2004, + "false_negatives": 241, + "auc_ci": [ + 0.9178884966596437, + 0.9178884966596437 + ], + "precision_ci": [ + 0.8017937840347082, + 0.8017937840347082 + ], + "recall_ci": [ + 0.9065606361828928, + 0.9065606361828928 + ], + "f1_ci": [ + 0.8509647346472855, + 0.8509647346472855 + ], + "specificity_ci": [ + 0.7758950532932412, + 0.7758950532932412 + ], + "npv_ci": [ + 0.8925162032262658, + 0.8925162032262658 + ], + "class_weights": { + "0.0": 0.70620388084176, + "1.0": 1.7123923127899272 + } + }, + "identity_hate": { + "auc": 0.9242209406948756, + "threshold": 0.042373284697532654, + "precision": 0.8424336725093711, + "recall": 0.8592057761732879, + "f1": 0.8507370677416805, + "specificity": 0.839296667348186, + "npv": 0.8563457480377756, + "positive_samples": 277, + "true_positives": 2220, + "false_positives": 415, + "true_negatives": 2168, + "false_negatives": 363, + "auc_ci": [ + 0.9242209406948756, + 0.9242209406948756 + ], + "precision_ci": [ + 0.8424336725093711, + 0.8424336725093711 + ], + "recall_ci": [ + 0.8592057761732879, + 0.8592057761732879 + ], + "f1_ci": [ + 0.8507370677416805, + 0.8507370677416805 + ], + "specificity_ci": [ + 0.839296667348186, + 0.839296667348186 + ], + "npv_ci": [ + 0.8563457480377756, + 0.8563457480377756 + ], + "class_weights": { + "0.0": 0.5283173175219792, + "1.0": 9.328519855595667 + } + } + }, + "sample_count": 5168 + }, + "4": { + "auc": 0.9418392933687934, + 
"precision": 0.7019672150256779, + "recall": 0.9036673990197736, + "f1": 0.766375554274002, + "hamming_loss": 0.1651803024428073, + "exact_match": 0.4955409073284219, + "specificity": 0.8245338509682739, + "class_metrics": { + "toxic": { + "auc": 0.9718317503718501, + "threshold": 0.4544762372970581, + "precision": 0.9205380327767301, + "recall": 0.9217594394705978, + "f1": 0.9211483312394544, + "specificity": 0.9204325994592514, + "npv": 0.9216554888385321, + "positive_samples": 2569, + "true_positives": 2377, + "false_positives": 205, + "true_negatives": 2373, + "false_negatives": 201, + "auc_ci": [ + 0.9718317503718501, + 0.9718317503718501 + ], + "precision_ci": [ + 0.9205380327767301, + 0.9205380327767301 + ], + "recall_ci": [ + 0.9217594394705978, + 0.9217594394705978 + ], + "f1_ci": [ + 0.9211483312394544, + 0.9211483312394544 + ], + "specificity_ci": [ + 0.9204325994592514, + 0.9204325994592514 + ], + "npv_ci": [ + 0.9216554888385321, + 0.9216554888385321 + ], + "class_weights": { + "0.0": 0.9961375048281189, + "1.0": 1.003892565200467 + } + }, + "severe_toxic": { + "auc": 0.8962662667751142, + "threshold": 0.0307308342307806, + "precision": 0.7913182428501319, + "recall": 0.8458333333333329, + "f1": 0.8176681460830066, + "specificity": 0.7769418462789687, + "npv": 0.834426745622858, + "positive_samples": 240, + "true_positives": 2181, + "false_positives": 575, + "true_negatives": 2003, + "false_negatives": 397, + "auc_ci": [ + 0.8962662667751142, + 0.8962662667751142 + ], + "precision_ci": [ + 0.7913182428501319, + 0.7913182428501319 + ], + "recall_ci": [ + 0.8458333333333329, + 0.8458333333333329 + ], + "f1_ci": [ + 0.8176681460830066, + 0.8176681460830066 + ], + "specificity_ci": [ + 0.7769418462789687, + 0.7769418462789687 + ], + "npv_ci": [ + 0.834426745622858, + 0.834426745622858 + ], + "class_weights": { + "0.0": 0.5244001626677511, + "1.0": 10.745833333333334 + } + }, + "obscene": { + "auc": 0.9401245966951454, + "threshold": 0.1775909662246704, + "precision": 0.8495468615216861, + "recall": 0.8913398692810475, + "f1": 0.8699417085541208, + "specificity": 0.8421453990848948, + "npv": 0.8857178178787266, + "positive_samples": 1224, + "true_positives": 2298, + "false_positives": 407, + "true_negatives": 2171, + "false_negatives": 280, + "auc_ci": [ + 0.9401245966951454, + 0.9401245966951454 + ], + "precision_ci": [ + 0.8495468615216861, + 0.8495468615216861 + ], + "recall_ci": [ + 0.8913398692810475, + 0.8913398692810475 + ], + "f1_ci": [ + 0.8699417085541208, + 0.8699417085541208 + ], + "specificity_ci": [ + 0.8421453990848948, + 0.8421453990848948 + ], + "npv_ci": [ + 0.8857178178787266, + 0.8857178178787266 + ], + "class_weights": { + "0.0": 0.6555668530757499, + "1.0": 2.1070261437908497 + } + }, + "threat": { + "auc": 0.8861722579224652, + "threshold": 0.014509523287415504, + "precision": 0.841106024006686, + "recall": 0.7943925233644874, + "f1": 0.81708215259711, + "specificity": 0.8499307067907416, + "npv": 0.8052107636996033, + "positive_samples": 107, + "true_positives": 2048, + "false_positives": 387, + "true_negatives": 2191, + "false_negatives": 530, + "auc_ci": [ + 0.8861722579224652, + 0.8861722579224652 + ], + "precision_ci": [ + 0.841106024006686, + 0.841106024006686 + ], + "recall_ci": [ + 0.7943925233644874, + 0.7943925233644874 + ], + "f1_ci": [ + 0.81708215259711, + 0.81708215259711 + ], + "specificity_ci": [ + 0.8499307067907416, + 0.8499307067907416 + ], + "npv_ci": [ + 0.8052107636996033, + 0.8052107636996033 + ], + "class_weights": { + "0.0": 
0.5105919619877252, + "1.0": 24.102803738317757 + } + }, + "insult": { + "auc": 0.908347099690273, + "threshold": 0.19917058944702148, + "precision": 0.787211545222267, + "recall": 0.9028609447771131, + "f1": 0.8410793781503274, + "specificity": 0.755950752393989, + "npv": 0.8861326740097348, + "positive_samples": 1503, + "true_positives": 2328, + "false_positives": 629, + "true_negatives": 1949, + "false_negatives": 250, + "auc_ci": [ + 0.908347099690273, + 0.908347099690273 + ], + "precision_ci": [ + 0.787211545222267, + 0.787211545222267 + ], + "recall_ci": [ + 0.9028609447771131, + 0.9028609447771131 + ], + "f1_ci": [ + 0.8410793781503274, + 0.8410793781503274 + ], + "specificity_ci": [ + 0.755950752393989, + 0.755950752393989 + ], + "npv_ci": [ + 0.8861326740097348, + 0.8861326740097348 + ], + "class_weights": { + "0.0": 0.7056087551299589, + "1.0": 1.7159015302727878 + } + }, + "identity_hate": { + "auc": 0.9136671508934288, + "threshold": 0.031982019543647766, + "precision": 0.8173388685191341, + "recall": 0.8868613138686137, + "f1": 0.8506820152960648, + "specificity": 0.801801801801802, + "npv": 0.8763431199913764, + "positive_samples": 274, + "true_positives": 2287, + "false_positives": 511, + "true_negatives": 2067, + "false_negatives": 291, + "auc_ci": [ + 0.9136671508934288, + 0.9136671508934288 + ], + "precision_ci": [ + 0.8173388685191341, + 0.8173388685191341 + ], + "recall_ci": [ + 0.8868613138686137, + 0.8868613138686137 + ], + "f1_ci": [ + 0.8506820152960648, + 0.8506820152960648 + ], + "specificity_ci": [ + 0.801801801801802, + 0.801801801801802 + ], + "npv_ci": [ + 0.8763431199913764, + 0.8763431199913764 + ], + "class_weights": { + "0.0": 0.528050778050778, + "1.0": 9.412408759124087 + } + } + }, + "sample_count": 5158 + }, + "5": { + "auc": 0.9460152147041221, + "precision": 0.7347347983801011, + "recall": 0.8867510548523206, + "f1": 0.7840490209789418, + "hamming_loss": 0.13677289804378806, + "exact_match": 0.5347842984842596, + "specificity": 0.8623489178772902, + "class_metrics": { + "toxic": { + "auc": 0.9757415342563065, + "threshold": 0.5313886404037476, + "precision": 0.9310023292772915, + "recall": 0.9121306376360682, + "f1": 0.9214698705828952, + "specificity": 0.9324009324009348, + "npv": 0.9138763886248709, + "positive_samples": 2572, + "true_positives": 2346, + "false_positives": 173, + "true_negatives": 2399, + "false_negatives": 226, + "auc_ci": [ + 0.9757415342563065, + 0.9757415342563065 + ], + "precision_ci": [ + 0.9310023292772915, + 0.9310023292772915 + ], + "recall_ci": [ + 0.9121306376360682, + 0.9121306376360682 + ], + "f1_ci": [ + 0.9214698705828952, + 0.9214698705828952 + ], + "specificity_ci": [ + 0.9324009324009348, + 0.9324009324009348 + ], + "npv_ci": [ + 0.9138763886248709, + 0.9138763886248709 + ], + "class_weights": { + "0.0": 0.9996114996114996, + "1.0": 1.0003888024883358 + } + }, + "severe_toxic": { + "auc": 0.9032281899714669, + "threshold": 0.05001964047551155, + "precision": 0.8240547826417868, + "recall": 0.8458333333333334, + "f1": 0.8348020409069885, + "specificity": 0.8194048104362093, + "npv": 0.8416483326674401, + "positive_samples": 240, + "true_positives": 2176, + "false_positives": 464, + "true_negatives": 2108, + "false_negatives": 396, + "auc_ci": [ + 0.9032281899714669, + 0.9032281899714669 + ], + "precision_ci": [ + 0.8240547826417868, + 0.8240547826417868 + ], + "recall_ci": [ + 0.8458333333333334, + 0.8458333333333334 + ], + "f1_ci": [ + 0.8348020409069885, + 0.8348020409069885 + ], + "specificity_ci": [ + 
0.8194048104362093, + 0.8194048104362093 + ], + "npv_ci": [ + 0.8416483326674401, + 0.8416483326674401 + ], + "class_weights": { + "0.0": 0.5244598450876478, + "1.0": 10.720833333333333 + } + }, + "obscene": { + "auc": 0.9399297347094935, + "threshold": 0.20134443044662476, + "precision": 0.8638120606436712, + "recall": 0.8799999999999917, + "f1": 0.8718308933886383, + "specificity": 0.8612598826829971, + "npv": 0.8777082380338568, + "positive_samples": 1225, + "true_positives": 2264, + "false_positives": 356, + "true_negatives": 2216, + "false_negatives": 308, + "auc_ci": [ + 0.9399297347094935, + 0.9399297347094935 + ], + "precision_ci": [ + 0.8638120606436712, + 0.8638120606436712 + ], + "recall_ci": [ + 0.8799999999999917, + 0.8799999999999917 + ], + "f1_ci": [ + 0.8718308933886383, + 0.8718308933886383 + ], + "specificity_ci": [ + 0.8612598826829971, + 0.8612598826829971 + ], + "npv_ci": [ + 0.8777082380338568, + 0.8777082380338568 + ], + "class_weights": { + "0.0": 0.6562101504718184, + "1.0": 2.100408163265306 + } + }, + "threat": { + "auc": 0.8786647405643102, + "threshold": 0.018557138741016388, + "precision": 0.8659949024954022, + "recall": 0.8055555555555568, + "f1": 0.834682556458845, + "specificity": 0.8753473600635171, + "npv": 0.8182408543184921, + "positive_samples": 108, + "true_positives": 2072, + "false_positives": 320, + "true_negatives": 2252, + "false_negatives": 500, + "auc_ci": [ + 0.8786647405643102, + 0.8786647405643102 + ], + "precision_ci": [ + 0.8659949024954022, + 0.8659949024954022 + ], + "recall_ci": [ + 0.8055555555555568, + 0.8055555555555568 + ], + "f1_ci": [ + 0.834682556458845, + 0.834682556458845 + ], + "specificity_ci": [ + 0.8753473600635171, + 0.8753473600635171 + ], + "npv_ci": [ + 0.8182408543184921, + 0.8182408543184921 + ], + "class_weights": { + "0.0": 0.5107185391028186, + "1.0": 23.824074074074073 + } + }, + "insult": { + "auc": 0.9170891169219639, + "threshold": 0.32249945402145386, + "precision": 0.8355108316117581, + "recall": 0.8716755319149065, + "f1": 0.8532101288125946, + "specificity": 0.8283909939593549, + "npv": 0.8658697667424693, + "positive_samples": 1504, + "true_positives": 2242, + "false_positives": 441, + "true_negatives": 2131, + "false_negatives": 330, + "auc_ci": [ + 0.9170891169219639, + 0.9170891169219639 + ], + "precision_ci": [ + 0.8355108316117581, + 0.8355108316117581 + ], + "recall_ci": [ + 0.8716755319149065, + 0.8716755319149065 + ], + "f1_ci": [ + 0.8532101288125946, + 0.8532101288125946 + ], + "specificity_ci": [ + 0.8283909939593549, + 0.8283909939593549 + ], + "npv_ci": [ + 0.8658697667424693, + 0.8658697667424693 + ], + "class_weights": { + "0.0": 0.7064799560680944, + "1.0": 1.7107712765957446 + } + }, + "identity_hate": { + "auc": 0.9171971252566641, + "threshold": 0.055891502648591995, + "precision": 0.8532420335871026, + "recall": 0.829710144927536, + "f1": 0.8413115718720496, + "specificity": 0.8572895277207252, + "npv": 0.8342805841339561, + "positive_samples": 276, + "true_positives": 2134, + "false_positives": 367, + "true_negatives": 2205, + "false_negatives": 438, + "auc_ci": [ + 0.9171971252566641, + 0.9171971252566641 + ], + "precision_ci": [ + 0.8532420335871026, + 0.8532420335871026 + ], + "recall_ci": [ + 0.829710144927536, + 0.829710144927536 + ], + "f1_ci": [ + 0.8413115718720496, + 0.8413115718720496 + ], + "specificity_ci": [ + 0.8572895277207252, + 0.8572895277207252 + ], + "npv_ci": [ + 0.8342805841339561, + 0.8342805841339561 + ], + "class_weights": { + "0.0": 0.5283367556468173, + 
"1.0": 9.322463768115941 + } + } + }, + "sample_count": 5146 + }, + "6": { + "auc": 0.9462815482574403, + "precision": 0.7134961462135606, + "recall": 0.9073793914943687, + "f1": 0.7744642816056855, + "hamming_loss": 0.15539933230611197, + "exact_match": 0.5132896764252697, + "specificity": 0.8360743701752594, + "class_metrics": { + "toxic": { + "auc": 0.9780732995232411, + "threshold": 0.5710838437080383, + "precision": 0.9379357119021944, + "recall": 0.9243012422360248, + "f1": 0.9310685643115885, + "specificity": 0.9388379204893005, + "npv": 0.9253858836387251, + "positive_samples": 2576, + "true_positives": 2399, + "false_positives": 158, + "true_negatives": 2437, + "false_negatives": 196, + "auc_ci": [ + 0.9780732995232411, + 0.9780732995232411 + ], + "precision_ci": [ + 0.9379357119021944, + 0.9379357119021944 + ], + "recall_ci": [ + 0.9243012422360248, + 0.9243012422360248 + ], + "f1_ci": [ + 0.9310685643115885, + 0.9310685643115885 + ], + "specificity_ci": [ + 0.9388379204893005, + 0.9388379204893005 + ], + "npv_ci": [ + 0.9253858836387251, + 0.9253858836387251 + ], + "class_weights": { + "0.0": 0.9923547400611621, + "1.0": 1.0077639751552796 + } + }, + "severe_toxic": { + "auc": 0.9067576592369966, + "threshold": 0.023807251825928688, + "precision": 0.7794259030353159, + "recall": 0.9380165289256208, + "f1": 0.8513989948241057, + "specificity": 0.7345454545454645, + "npv": 0.9221830255239729, + "positive_samples": 242, + "true_positives": 2435, + "false_positives": 689, + "true_negatives": 1906, + "false_negatives": 160, + "auc_ci": [ + 0.9067576592369966, + 0.9067576592369966 + ], + "precision_ci": [ + 0.7794259030353159, + 0.7794259030353159 + ], + "recall_ci": [ + 0.9380165289256208, + 0.9380165289256208 + ], + "f1_ci": [ + 0.8513989948241057, + 0.8513989948241057 + ], + "specificity_ci": [ + 0.7345454545454645, + 0.7345454545454645 + ], + "npv_ci": [ + 0.9221830255239729, + 0.9221830255239729 + ], + "class_weights": { + "0.0": 0.5244444444444445, + "1.0": 10.727272727272727 + } + }, + "obscene": { + "auc": 0.9375048626461102, + "threshold": 0.14760328829288483, + "precision": 0.8287449241470627, + "recall": 0.9084278768233371, + "f1": 0.8667588986547364, + "specificity": 0.8122789287518954, + "npv": 0.8986867106241987, + "positive_samples": 1234, + "true_positives": 2358, + "false_positives": 487, + "true_negatives": 2108, + "false_negatives": 237, + "auc_ci": [ + 0.9375048626461102, + 0.9375048626461102 + ], + "precision_ci": [ + 0.8287449241470627, + 0.8287449241470627 + ], + "recall_ci": [ + 0.9084278768233371, + 0.9084278768233371 + ], + "f1_ci": [ + 0.8667588986547364, + 0.8667588986547364 + ], + "specificity_ci": [ + 0.8122789287518954, + 0.8122789287518954 + ], + "npv_ci": [ + 0.8986867106241987, + 0.8986867106241987 + ], + "class_weights": { + "0.0": 0.6558868115209702, + "1.0": 2.1037277147487843 + } + }, + "threat": { + "auc": 0.9031869137455802, + "threshold": 0.026773449033498764, + "precision": 0.9112427696973145, + "recall": 0.761467889908257, + "f1": 0.8296498919893159, + "specificity": 0.9258312020460328, + "npv": 0.7951394486538688, + "positive_samples": 109, + "true_positives": 1976, + "false_positives": 192, + "true_negatives": 2403, + "false_negatives": 619, + "auc_ci": [ + 0.9031869137455802, + 0.9031869137455802 + ], + "precision_ci": [ + 0.9112427696973145, + 0.9112427696973145 + ], + "recall_ci": [ + 0.761467889908257, + 0.761467889908257 + ], + "f1_ci": [ + 0.8296498919893159, + 0.8296498919893159 + ], + "specificity_ci": [ + 0.9258312020460328, + 
0.9258312020460328 + ], + "npv_ci": [ + 0.7951394486538688, + 0.7951394486538688 + ], + "class_weights": { + "0.0": 0.5107220145583317, + "1.0": 23.81651376146789 + } + }, + "insult": { + "auc": 0.9164838070297321, + "threshold": 0.2600024938583374, + "precision": 0.8178816065079044, + "recall": 0.8940397350993466, + "f1": 0.8542666500534941, + "specificity": 0.8009234111895767, + "npv": 0.8831600262588531, + "positive_samples": 1510, + "true_positives": 2320, + "false_positives": 516, + "true_negatives": 2079, + "false_negatives": 275, + "auc_ci": [ + 0.9164838070297321, + 0.9164838070297321 + ], + "precision_ci": [ + 0.8178816065079044, + 0.8178816065079044 + ], + "recall_ci": [ + 0.8940397350993466, + 0.8940397350993466 + ], + "f1_ci": [ + 0.8542666500534941, + 0.8542666500534941 + ], + "specificity_ci": [ + 0.8009234111895767, + 0.8009234111895767 + ], + "npv_ci": [ + 0.8831600262588531, + 0.8831600262588531 + ], + "class_weights": { + "0.0": 0.7050516023900054, + "1.0": 1.719205298013245 + } + }, + "identity_hate": { + "auc": 0.9038051609994096, + "threshold": 0.03315547853708267, + "precision": 0.8124487711378064, + "recall": 0.8489208633093526, + "f1": 0.8302844808144539, + "specificity": 0.804029304029316, + "npv": 0.8418199125360486, + "positive_samples": 278, + "true_positives": 2203, + "false_positives": 508, + "true_negatives": 2087, + "false_negatives": 392, + "auc_ci": [ + 0.9038051609994096, + 0.9038051609994096 + ], + "precision_ci": [ + 0.8124487711378064, + 0.8124487711378064 + ], + "recall_ci": [ + 0.8489208633093526, + 0.8489208633093526 + ], + "f1_ci": [ + 0.8302844808144539, + 0.8302844808144539 + ], + "specificity_ci": [ + 0.804029304029316, + 0.804029304029316 + ], + "npv_ci": [ + 0.8418199125360486, + 0.8418199125360486 + ], + "class_weights": { + "0.0": 0.5282865282865283, + "1.0": 9.338129496402878 + } + } + }, + "sample_count": 5192 + } + }, + "per_class": {}, + "thresholds": { + "0": { + "toxic": 0.46047261357307434, + "severe_toxic": 0.03537772223353386, + "obscene": 0.2777131497859955, + "threat": 0.016539234668016434, + "insult": 0.25907590985298157, + "identity_hate": 0.026042653247714043 + }, + "1": { + "toxic": 0.44148319959640503, + "severe_toxic": 0.03648429363965988, + "obscene": 0.1990610957145691, + "threat": 0.012619060464203358, + "insult": 0.24214455485343933, + "identity_hate": 0.03167847916483879 + }, + "2": { + "toxic": 0.3978160321712494, + "severe_toxic": 0.015000982210040092, + "obscene": 0.11362762749195099, + "threat": 0.008195769973099232, + "insult": 0.1587354838848114, + "identity_hate": 0.0467526838183403 + }, + "3": { + "toxic": 0.5033379793167114, + "severe_toxic": 0.021415209397673607, + "obscene": 0.14896434545516968, + "threat": 0.013273251242935658, + "insult": 0.22368550300598145, + "identity_hate": 0.042373284697532654 + }, + "4": { + "toxic": 0.4544762372970581, + "severe_toxic": 0.0307308342307806, + "obscene": 0.1775909662246704, + "threat": 0.014509523287415504, + "insult": 0.19917058944702148, + "identity_hate": 0.031982019543647766 + }, + "5": { + "toxic": 0.5313886404037476, + "severe_toxic": 0.05001964047551155, + "obscene": 0.20134443044662476, + "threat": 0.018557138741016388, + "insult": 0.32249945402145386, + "identity_hate": 0.055891502648591995 + }, + "6": { + "toxic": 0.5710838437080383, + "severe_toxic": 0.023807251825928688, + "obscene": 0.14760328829288483, + "threat": 0.026773449033498764, + "insult": 0.2600024938583374, + "identity_hate": 0.03315547853708267 + } + } +} \ No newline at end of file diff --git 
a/evaluation_results/eval_20250208_161149/plots/calibration_0.png b/evaluation_results/eval_20250208_161149/plots/calibration_0.png new file mode 100644 index 0000000000000000000000000000000000000000..770750eab14da304f5d92cc5fa63185bb99e9008 --- /dev/null +++ b/evaluation_results/eval_20250208_161149/plots/calibration_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1e520af6af852f9edeef0bc12c53741ec9028a81d0d7fc7105e7abe02c1121d7 +size 111613 diff --git a/evaluation_results/eval_20250208_161149/plots/calibration_1.png b/evaluation_results/eval_20250208_161149/plots/calibration_1.png new file mode 100644 index 0000000000000000000000000000000000000000..ddfbd55e7263fd2eb6c775c33688ae9d25b65e24 --- /dev/null +++ b/evaluation_results/eval_20250208_161149/plots/calibration_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d6bbf38e6f262d27f5209c1bb0b8174259b6183978e8844fb84ccd2b43810be0 +size 111026 diff --git a/evaluation_results/eval_20250208_161149/plots/calibration_2.png b/evaluation_results/eval_20250208_161149/plots/calibration_2.png new file mode 100644 index 0000000000000000000000000000000000000000..a4a6e148bfb16cdf59b9564f9fab3c24f366a8fe --- /dev/null +++ b/evaluation_results/eval_20250208_161149/plots/calibration_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:617690b2d238fcd53726552b1b979612f943976b2652013a078c6ce4d2496060 +size 110177 diff --git a/evaluation_results/eval_20250208_161149/plots/calibration_3.png b/evaluation_results/eval_20250208_161149/plots/calibration_3.png new file mode 100644 index 0000000000000000000000000000000000000000..82847e711ac36ef6d0d78f3ff1008139c49c69a6 --- /dev/null +++ b/evaluation_results/eval_20250208_161149/plots/calibration_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:78fe5b71ba88524f96205ba367b7b643d864a55bcd702e15be7c9a27e2a43007 +size 111311 diff --git a/evaluation_results/eval_20250208_161149/plots/calibration_4.png b/evaluation_results/eval_20250208_161149/plots/calibration_4.png new file mode 100644 index 0000000000000000000000000000000000000000..4aa4cf72682d247fc27db987f80ff2051185f131 --- /dev/null +++ b/evaluation_results/eval_20250208_161149/plots/calibration_4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8e2df03bc3e34ccdf6c2a7c5fe305b0da7f4d6185948464ddc67f1ec4618b2f +size 110370 diff --git a/evaluation_results/eval_20250208_161149/plots/calibration_5.png b/evaluation_results/eval_20250208_161149/plots/calibration_5.png new file mode 100644 index 0000000000000000000000000000000000000000..f09b0668148eec4ec9a88019b3a3735f6ab503a7 --- /dev/null +++ b/evaluation_results/eval_20250208_161149/plots/calibration_5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4c1e2d2529ebc23a3c1d07daf36507ea66208f49396dce354fc0ab6c8baa14a +size 110324 diff --git a/evaluation_results/eval_20250208_161149/plots/calibration_6.png b/evaluation_results/eval_20250208_161149/plots/calibration_6.png new file mode 100644 index 0000000000000000000000000000000000000000..ba2a7564abce5c839f01f258410f19df72cdc03d --- /dev/null +++ b/evaluation_results/eval_20250208_161149/plots/calibration_6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d912dedb7a79ea521921afb948696787b2ba6206137d3359b619936e24455101 +size 110780 diff --git a/evaluation_results/eval_20250208_161149/plots/class_calibration.png b/evaluation_results/eval_20250208_161149/plots/class_calibration.png new 
file mode 100644 index 0000000000000000000000000000000000000000..0c8cb305b432d753dd211bc73e81d1387c41bb8e --- /dev/null +++ b/evaluation_results/eval_20250208_161149/plots/class_calibration.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f0fed51177ba858d2fa386a5198020a00dc58255cb3940526072c2866f71212 +size 111678 diff --git a/evaluation_results/eval_20250208_161149/plots/language_performance.png b/evaluation_results/eval_20250208_161149/plots/language_performance.png new file mode 100644 index 0000000000000000000000000000000000000000..5a1b2b81d36d9482117cec41dacfe6801c9ab507 Binary files /dev/null and b/evaluation_results/eval_20250208_161149/plots/language_performance.png differ diff --git a/evaluation_results/eval_20250208_161149/plots/metric_correlations.png b/evaluation_results/eval_20250208_161149/plots/metric_correlations.png new file mode 100644 index 0000000000000000000000000000000000000000..1ea41f3d102ded809bccfd9ed136da967fe11359 Binary files /dev/null and b/evaluation_results/eval_20250208_161149/plots/metric_correlations.png differ diff --git a/evaluation_results/eval_20250208_161149/plots/overall_calibration.png b/evaluation_results/eval_20250208_161149/plots/overall_calibration.png new file mode 100644 index 0000000000000000000000000000000000000000..2318908216b3565533f22a4a84cc49e1364360ea Binary files /dev/null and b/evaluation_results/eval_20250208_161149/plots/overall_calibration.png differ diff --git a/evaluation_results/eval_20250208_161149/plots/performance_distributions.png b/evaluation_results/eval_20250208_161149/plots/performance_distributions.png new file mode 100644 index 0000000000000000000000000000000000000000..1ec382209540cc46368655dd633ad82c1b705156 Binary files /dev/null and b/evaluation_results/eval_20250208_161149/plots/performance_distributions.png differ diff --git a/evaluation_results/eval_20250208_161149/predictions.npz b/evaluation_results/eval_20250208_161149/predictions.npz new file mode 100644 index 0000000000000000000000000000000000000000..2e1a2ea03585ae1f8063ba63b878acb145fb40fe --- /dev/null +++ b/evaluation_results/eval_20250208_161149/predictions.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d562e6c02fc268d01464f9716846556a75e863ec9cc03d582f39e14191cbd496 +size 809713 diff --git a/evaluation_results/eval_20250208_161149/thresholds.json b/evaluation_results/eval_20250208_161149/thresholds.json new file mode 100644 index 0000000000000000000000000000000000000000..58b1173d1f99d762f9e71f4ed8ffef323c910f50 --- /dev/null +++ b/evaluation_results/eval_20250208_161149/thresholds.json @@ -0,0 +1,58 @@ +{ + "0": { + "toxic": 0.46047261357307434, + "severe_toxic": 0.03537772223353386, + "obscene": 0.2777131497859955, + "threat": 0.016539234668016434, + "insult": 0.25907590985298157, + "identity_hate": 0.026042653247714043 + }, + "1": { + "toxic": 0.44148319959640503, + "severe_toxic": 0.03648429363965988, + "obscene": 0.1990610957145691, + "threat": 0.012619060464203358, + "insult": 0.24214455485343933, + "identity_hate": 0.03167847916483879 + }, + "2": { + "toxic": 0.3978160321712494, + "severe_toxic": 0.015000982210040092, + "obscene": 0.11362762749195099, + "threat": 0.008195769973099232, + "insult": 0.1587354838848114, + "identity_hate": 0.0467526838183403 + }, + "3": { + "toxic": 0.5033379793167114, + "severe_toxic": 0.021415209397673607, + "obscene": 0.14896434545516968, + "threat": 0.013273251242935658, + "insult": 0.22368550300598145, + "identity_hate": 0.042373284697532654 + }, + "4": { + 
"toxic": 0.4544762372970581, + "severe_toxic": 0.0307308342307806, + "obscene": 0.1775909662246704, + "threat": 0.014509523287415504, + "insult": 0.19917058944702148, + "identity_hate": 0.031982019543647766 + }, + "5": { + "toxic": 0.5313886404037476, + "severe_toxic": 0.05001964047551155, + "obscene": 0.20134443044662476, + "threat": 0.018557138741016388, + "insult": 0.32249945402145386, + "identity_hate": 0.055891502648591995 + }, + "6": { + "toxic": 0.5710838437080383, + "severe_toxic": 0.023807251825928688, + "obscene": 0.14760328829288483, + "threat": 0.026773449033498764, + "insult": 0.2600024938583374, + "identity_hate": 0.03315547853708267 + } +} \ No newline at end of file diff --git a/evaluation_results/eval_20250401_143401/eval_params.json b/evaluation_results/eval_20250401_143401/eval_params.json new file mode 100644 index 0000000000000000000000000000000000000000..eaac7fb2149ff00af0f34cd257595c38e1a4cf0a --- /dev/null +++ b/evaluation_results/eval_20250401_143401/eval_params.json @@ -0,0 +1,21 @@ +{ + "timestamp": "20250401_143401", + "model_path": "weights/toxic_classifier_xlm-roberta-large", + "checkpoint": null, + "test_file": "dataset/split/val.csv", + "batch_size": 64, + "num_workers": 16, + "cache_dir": "cached_data", + "force_retokenize": false, + "prefetch_factor": 2, + "max_length": 128, + "gc_frequency": 500, + "label_columns": [ + "toxic", + "severe_toxic", + "obscene", + "threat", + "insult", + "identity_hate" + ] +} \ No newline at end of file diff --git a/evaluation_results/eval_20250401_143401/evaluation_results.json b/evaluation_results/eval_20250401_143401/evaluation_results.json new file mode 100644 index 0000000000000000000000000000000000000000..13dd851d3c486fd2aabd70a23fa4f2d45a784e4f --- /dev/null +++ b/evaluation_results/eval_20250401_143401/evaluation_results.json @@ -0,0 +1,684 @@ +{ + "default_thresholds": { + "overall": { + "auc_macro": 0.9116120481007194, + "auc_weighted": 0.9305869103434485, + "precision_macro": 0.7017348731216243, + "precision_weighted": 0.7941268867549155, + "recall_macro": 0.4685972374699909, + "recall_weighted": 0.7276981501898812, + "f1_macro": 0.5228946160541719, + "f1_weighted": 0.7469638283202927, + "hamming_loss": 0.08497391889618038, + "exact_match": 0.6461383139828369 + }, + "per_language": { + "0": { + "auc_macro": 0.9445681226397739, + "auc_weighted": 0.9465404082666297, + "precision_macro": 0.7219326082283263, + "precision_weighted": 0.7908382685179838, + "recall_macro": 0.5535398284592582, + "recall_weighted": 0.7833787465940054, + "f1_macro": 0.6000668677340134, + "f1_weighted": 0.7786737821480415, + "hamming_loss": 0.07650567773465575, + "exact_match": 0.6601983613626563, + "sample_count": 4638 + }, + "1": { + "auc_macro": 0.9064189306891727, + "auc_weighted": 0.9274078123911156, + "precision_macro": 0.6864158919056594, + "precision_weighted": 0.7852581089086744, + "recall_macro": 0.44366116589032245, + "recall_weighted": 0.7238780977896851, + "f1_macro": 0.48488161881757197, + "f1_weighted": 0.737051270947713, + "hamming_loss": 0.08752166377816291, + "exact_match": 0.6402849990371654, + "sample_count": 5193 + }, + "2": { + "auc_macro": 0.8945135400492461, + "auc_weighted": 0.9120120071881025, + "precision_macro": 0.7178271955012184, + "precision_weighted": 0.7982113173628885, + "recall_macro": 0.4043111379749362, + "recall_weighted": 0.6535947712418301, + "f1_macro": 0.4738257066120983, + "f1_weighted": 0.7027905834489889, + "hamming_loss": 0.09504905757810483, + "exact_match": 0.6229666924864447, + "sample_count": 
5164 + }, + "3": { + "auc_macro": 0.9135727964673032, + "auc_weighted": 0.9339502655719858, + "precision_macro": 0.7093511783545062, + "precision_weighted": 0.7989932896421867, + "recall_macro": 0.4814045378504133, + "recall_weighted": 0.7405478070912451, + "f1_macro": 0.5327086132158053, + "f1_weighted": 0.7545000455696493, + "hamming_loss": 0.08359133126934984, + "exact_match": 0.6480263157894737, + "sample_count": 5168 + }, + "4": { + "auc_macro": 0.9050160058685811, + "auc_weighted": 0.9286663336151794, + "precision_macro": 0.6819384343494851, + "precision_weighted": 0.7945304496145832, + "recall_macro": 0.4656370270227365, + "recall_weighted": 0.7256427604871448, + "f1_macro": 0.5189060171591118, + "f1_weighted": 0.7474398480273773, + "hamming_loss": 0.08477150798267727, + "exact_match": 0.6509598603839442, + "sample_count": 5157 + }, + "5": { + "auc_macro": 0.9115535221829411, + "auc_weighted": 0.9337271942250184, + "precision_macro": 0.6927437323462047, + "precision_weighted": 0.7984424245250574, + "recall_macro": 0.4695924180409275, + "recall_weighted": 0.739629005059022, + "f1_macro": 0.5191221600663896, + "f1_weighted": 0.7554966948679994, + "hamming_loss": 0.08252364295893251, + "exact_match": 0.6525456665371162, + "sample_count": 5146 + }, + "6": { + "auc_macro": 0.9045493247421005, + "auc_weighted": 0.9308415576648513, + "precision_macro": 0.6958021612757893, + "precision_weighted": 0.7925797967619269, + "recall_macro": 0.4680867128534896, + "recall_weighted": 0.735071488645921, + "f1_macro": 0.5184729138243417, + "f1_weighted": 0.7510735996739993, + "hamming_loss": 0.0839753466872111, + "exact_match": 0.6494607087827426, + "sample_count": 5192 + } + }, + "per_class": { + "toxic": { + "auc": 0.9619106577495796, + "threshold": 0.5, + "precision": 0.9067127628925382, + "recall": 0.8891902582358592, + "f1": 0.8978660276161132, + "support": 17697, + "brier": 0.09342169378057544, + "true_positives": 15736, + "false_positives": 1619, + "true_negatives": 16342, + "false_negatives": 1961 + }, + "severe_toxic": { + "auc": 0.9017555053121755, + "threshold": 0.5, + "precision": 0.5620915032679739, + "recall": 0.15589123867069488, + "f1": 0.24408703878902555, + "support": 1655, + "brier": 0.05564494143865772, + "true_positives": 258, + "false_positives": 201, + "true_negatives": 33802, + "false_negatives": 1397 + }, + "obscene": { + "auc": 0.9247491461802884, + "threshold": 0.5, + "precision": 0.7636434008515031, + "recall": 0.686181312311616, + "f1": 0.7228430115405752, + "support": 8626, + "brier": 0.1102165916686836, + "true_positives": 5919, + "false_positives": 1832, + "true_negatives": 25200, + "false_negatives": 2707 + }, + "threat": { + "auc": 0.8978719938708597, + "threshold": 0.5, + "precision": 0.6042553191489362, + "recall": 0.1868421052631579, + "f1": 0.28542713567839195, + "support": 760, + "brier": 0.03694216309848939, + "true_positives": 142, + "false_positives": 93, + "true_negatives": 34805, + "false_negatives": 618 + }, + "insult": { + "auc": 0.8962985964590791, + "threshold": 0.5, + "precision": 0.6981960484871623, + "recall": 0.7172271791352093, + "f1": 0.7075836718901142, + "support": 10199, + "brier": 0.1366709113756841, + "true_positives": 7315, + "false_positives": 3162, + "true_negatives": 22297, + "false_negatives": 2884 + }, + "identity_hate": { + "auc": 0.887086389032334, + "threshold": 0.5, + "precision": 0.6755102040816326, + "recall": 0.17625133120340788, + "f1": 0.2795608108108108, + "support": 1878, + "brier": 0.06076370760519854, + "true_positives": 331, 
+ "false_positives": 159, + "true_negatives": 33621, + "false_negatives": 1547 + } + } + }, + "optimized_thresholds": { + "overall": { + "auc_macro": 0.9116120481007194, + "auc_weighted": 0.9305869103434485, + "precision_macro": 0.5775888380947196, + "precision_weighted": 0.7443465124836487, + "recall_macro": 0.639900823721825, + "recall_weighted": 0.798186941075585, + "f1_macro": 0.6040131510667749, + "f1_weighted": 0.7686775463209056, + "hamming_loss": 0.09459775272496121, + "exact_match": 0.6191317516405855 + }, + "per_language": { + "0": { + "auc_macro": 0.9445681226397739, + "auc_weighted": 0.9465404082666297, + "precision_macro": 0.5885969911405202, + "precision_weighted": 0.7416734521846035, + "recall_macro": 0.7381385425477333, + "recall_weighted": 0.8514986376021798, + "f1_macro": 0.6497623010487168, + "f1_weighted": 0.7903759805291908, + "hamming_loss": 0.08746586172200661, + "exact_match": 0.6282880551962052, + "sample_count": 4638 + }, + "1": { + "auc_macro": 0.9064189306891727, + "auc_weighted": 0.9274078123911156, + "precision_macro": 0.5769491938694048, + "precision_weighted": 0.7372462490399235, + "recall_macro": 0.6223651765807731, + "recall_weighted": 0.7957133288680509, + "f1_macro": 0.5940383621467368, + "f1_weighted": 0.7630519259035966, + "hamming_loss": 0.09734257654534952, + "exact_match": 0.6112073945696129, + "sample_count": 5193 + }, + "2": { + "auc_macro": 0.8945135400492461, + "auc_weighted": 0.9120120071881025, + "precision_macro": 0.5883546567568967, + "precision_weighted": 0.7471472711374241, + "recall_macro": 0.5741089328356292, + "recall_weighted": 0.7323613205966147, + "f1_macro": 0.579910490554519, + "f1_weighted": 0.7393192722268676, + "hamming_loss": 0.10030983733539892, + "exact_match": 0.6094113090627421, + "sample_count": 5164 + }, + "3": { + "auc_macro": 0.9135727964673032, + "auc_weighted": 0.9339502655719858, + "precision_macro": 0.5674300764951785, + "precision_weighted": 0.7452385794349706, + "recall_macro": 0.6585754182827804, + "recall_weighted": 0.8117963367501261, + "f1_macro": 0.6075512335059755, + "f1_weighted": 0.7751847838928642, + "hamming_loss": 0.09404024767801858, + "exact_match": 0.6234520123839009, + "sample_count": 5168 + }, + "4": { + "auc_macro": 0.9050160058685811, + "auc_weighted": 0.9286663336151794, + "precision_macro": 0.5635774868138544, + "precision_weighted": 0.7453012013072762, + "recall_macro": 0.6307198572670079, + "recall_weighted": 0.793640054127199, + "f1_macro": 0.5906173214394316, + "f1_weighted": 0.7663604150980545, + "hamming_loss": 0.0963415422403206, + "exact_match": 0.6162497576110142, + "sample_count": 5157 + }, + "5": { + "auc_macro": 0.9115535221829411, + "auc_weighted": 0.9337271942250184, + "precision_macro": 0.577007586897046, + "precision_weighted": 0.7468873881119108, + "recall_macro": 0.635638229939968, + "recall_weighted": 0.8080944350758853, + "f1_macro": 0.5988862551226474, + "f1_weighted": 0.7742215916662522, + "hamming_loss": 0.09350304443580774, + "exact_match": 0.6195102992615624, + "sample_count": 5146 + }, + "6": { + "auc_macro": 0.9045493247421005, + "auc_weighted": 0.9308415576648513, + "precision_macro": 0.591572349044604, + "precision_weighted": 0.749047954356656, + "recall_macro": 0.6294384348455582, + "recall_weighted": 0.8016820857863751, + "f1_macro": 0.6039252504591597, + "f1_weighted": 0.772582192067038, + "hamming_loss": 0.09244992295839753, + "exact_match": 0.6267334360554699, + "sample_count": 5192 + } + }, + "per_class": { + "toxic": { + "auc": 0.9619106577495796, + 
"threshold": 0.4877551020408163, + "precision": 0.8999716472923164, + "recall": 0.8968186698310449, + "f1": 0.8983923921657421, + "support": 17697, + "brier": 0.09342169378057544, + "true_positives": 15871, + "false_positives": 1764, + "true_negatives": 16197, + "false_negatives": 1826 + }, + "severe_toxic": { + "auc": 0.9017555053121755, + "threshold": 0.373469387755102, + "precision": 0.34626149540183926, + "recall": 0.5232628398791541, + "f1": 0.4167468719923003, + "support": 1655, + "brier": 0.05564494143865772, + "true_positives": 866, + "false_positives": 1635, + "true_negatives": 32368, + "false_negatives": 789 + }, + "obscene": { + "auc": 0.9247491461802884, + "threshold": 0.4551020408163265, + "precision": 0.7017099430018999, + "recall": 0.770693252956179, + "f1": 0.734585635359116, + "support": 8626, + "brier": 0.1102165916686836, + "true_positives": 6648, + "false_positives": 2826, + "true_negatives": 24206, + "false_negatives": 1978 + }, + "threat": { + "auc": 0.8978719938708597, + "threshold": 0.38979591836734695, + "precision": 0.43684992570579495, + "recall": 0.3868421052631579, + "f1": 0.41032798325191905, + "support": 760, + "brier": 0.03694216309848939, + "true_positives": 294, + "false_positives": 379, + "true_negatives": 34519, + "false_negatives": 466 + }, + "insult": { + "auc": 0.8962985964590791, + "threshold": 0.463265306122449, + "precision": 0.6568989575638184, + "recall": 0.7846847730169625, + "f1": 0.7151282280403896, + "support": 10199, + "brier": 0.1366709113756841, + "true_positives": 8003, + "false_positives": 4180, + "true_negatives": 21279, + "false_negatives": 2196 + }, + "identity_hate": { + "auc": 0.887086389032334, + "threshold": 0.373469387755102, + "precision": 0.423841059602649, + "recall": 0.47710330138445156, + "f1": 0.44889779559118237, + "support": 1878, + "brier": 0.06076370760519854, + "true_positives": 896, + "false_positives": 1218, + "true_negatives": 32562, + "false_negatives": 982 + } + } + }, + "thresholds": { + "global": { + "toxic": { + "threshold": 0.4877551020408163, + "f1_score": 0.8926184748925591, + "support": 17697, + "total_samples": 35658 + }, + "severe_toxic": { + "threshold": 0.373469387755102, + "f1_score": 0.41132469871513055, + "support": 1655, + "total_samples": 35658 + }, + "obscene": { + "threshold": 0.4551020408163265, + "f1_score": 0.726924984126118, + "support": 8626, + "total_samples": 35658 + }, + "threat": { + "threshold": 0.38979591836734695, + "f1_score": 0.41018044345470683, + "support": 760, + "total_samples": 35658 + }, + "insult": { + "threshold": 0.463265306122449, + "f1_score": 0.7104171976414078, + "support": 10199, + "total_samples": 35658 + }, + "identity_hate": { + "threshold": 0.373469387755102, + "f1_score": 0.4444212159518569, + "support": 1878, + "total_samples": 35658 + } + }, + "per_language": { + "0": { + "toxic": { + "threshold": 0.4379310344827586, + "f1_score": 0.6362062357467935, + "support": 2228, + "total_samples": 4638 + }, + "severe_toxic": { + "threshold": 0.4241379310344827, + "f1_score": 0.6836346572759443, + "support": 199, + "total_samples": 4638 + }, + "obscene": { + "threshold": 0.4655172413793103, + "f1_score": 0.4812423489705398, + "support": 1235, + "total_samples": 4638 + }, + "threat": { + "threshold": 0.4655172413793103, + "f1_score": 0.560716193430073, + "support": 118, + "total_samples": 4638 + }, + "insult": { + "threshold": 0.6586206896551723, + "f1_score": 0.6797683196093679, + "support": 1144, + "total_samples": 4638 + }, + "identity_hate": { + "threshold": 
0.6310344827586206, + "f1_score": 0.4653856089660791, + "support": 214, + "total_samples": 4638 + } + }, + "1": { + "toxic": { + "threshold": 0.38275862068965516, + "f1_score": 0.5653885349662379, + "support": 2589, + "total_samples": 5193 + }, + "severe_toxic": { + "threshold": 0.36896551724137927, + "f1_score": 0.6303988062940857, + "support": 245, + "total_samples": 5193 + }, + "obscene": { + "threshold": 0.6724137931034482, + "f1_score": 0.69776888519452, + "support": 1239, + "total_samples": 5193 + }, + "threat": { + "threshold": 0.5482758620689655, + "f1_score": 0.49444444444444446, + "support": 106, + "total_samples": 5193 + }, + "insult": { + "threshold": 0.45172413793103444, + "f1_score": 0.43592427815977264, + "support": 1514, + "total_samples": 5193 + }, + "identity_hate": { + "threshold": 0.603448275862069, + "f1_score": 0.437278850182076, + "support": 279, + "total_samples": 5193 + } + }, + "2": { + "toxic": { + "threshold": 0.36896551724137927, + "f1_score": 0.5636259188109024, + "support": 2585, + "total_samples": 5164 + }, + "severe_toxic": { + "threshold": 0.396551724137931, + "f1_score": 0.6242565552619788, + "support": 243, + "total_samples": 5164 + }, + "obscene": { + "threshold": 0.6310344827586206, + "f1_score": 0.609064783177638, + "support": 1233, + "total_samples": 5164 + }, + "threat": { + "threshold": 0.6862068965517241, + "f1_score": 0.4331632653061225, + "support": 110, + "total_samples": 5164 + }, + "insult": { + "threshold": 0.6586206896551723, + "f1_score": 0.5919194590653671, + "support": 1514, + "total_samples": 5164 + }, + "identity_hate": { + "threshold": 0.5896551724137931, + "f1_score": 0.44181963497241983, + "support": 282, + "total_samples": 5164 + } + }, + "3": { + "toxic": { + "threshold": 0.35517241379310344, + "f1_score": 0.5733103161693534, + "support": 2579, + "total_samples": 5168 + }, + "severe_toxic": { + "threshold": 0.38275862068965516, + "f1_score": 0.6597492750378473, + "support": 243, + "total_samples": 5168 + }, + "obscene": { + "threshold": 0.5896551724137931, + "f1_score": 0.5803338639295222, + "support": 1234, + "total_samples": 5168 + }, + "threat": { + "threshold": 0.5896551724137931, + "f1_score": 0.5531975271105706, + "support": 108, + "total_samples": 5168 + }, + "insult": { + "threshold": 0.4103448275862069, + "f1_score": 0.43932768516388326, + "support": 1511, + "total_samples": 5168 + }, + "identity_hate": { + "threshold": 0.5482758620689655, + "f1_score": 0.5223443223443224, + "support": 276, + "total_samples": 5168 + } + }, + "4": { + "toxic": { + "threshold": 0.36896551724137927, + "f1_score": 0.5671790360963849, + "support": 2568, + "total_samples": 5157 + }, + "severe_toxic": { + "threshold": 0.4241379310344827, + "f1_score": 0.6449236298292902, + "support": 240, + "total_samples": 5157 + }, + "obscene": { + "threshold": 0.5896551724137931, + "f1_score": 0.5763915317957939, + "support": 1225, + "total_samples": 5157 + }, + "threat": { + "threshold": 0.5482758620689655, + "f1_score": 0.5202898550724637, + "support": 105, + "total_samples": 5157 + }, + "insult": { + "threshold": 0.45172413793103444, + "f1_score": 0.44168323420099964, + "support": 1501, + "total_samples": 5157 + }, + "identity_hate": { + "threshold": 0.5344827586206896, + "f1_score": 0.3050612442147916, + "support": 273, + "total_samples": 5157 + } + }, + "5": { + "toxic": { + "threshold": 0.38275862068965516, + "f1_score": 0.5689208863252881, + "support": 2572, + "total_samples": 5146 + }, + "severe_toxic": { + "threshold": 0.38275862068965516, + 
"f1_score": 0.6483406115143644, + "support": 242, + "total_samples": 5146 + }, + "obscene": { + "threshold": 0.6172413793103448, + "f1_score": 0.7591744574190955, + "support": 1227, + "total_samples": 5146 + }, + "threat": { + "threshold": 0.5896551724137931, + "f1_score": 0.48909813468905516, + "support": 106, + "total_samples": 5146 + }, + "insult": { + "threshold": 0.4655172413793103, + "f1_score": 0.4438765689644482, + "support": 1506, + "total_samples": 5146 + }, + "identity_hate": { + "threshold": 0.4655172413793103, + "f1_score": 0.57592394533571, + "support": 277, + "total_samples": 5146 + } + }, + "6": { + "toxic": { + "threshold": 0.396551724137931, + "f1_score": 0.5707684299142913, + "support": 2576, + "total_samples": 5192 + }, + "severe_toxic": { + "threshold": 0.38275862068965516, + "f1_score": 0.6300280234278585, + "support": 243, + "total_samples": 5192 + }, + "obscene": { + "threshold": 0.603448275862069, + "f1_score": 0.5508854395728676, + "support": 1233, + "total_samples": 5192 + }, + "threat": { + "threshold": 0.4655172413793103, + "f1_score": 0.6029992790194665, + "support": 107, + "total_samples": 5192 + }, + "insult": { + "threshold": 0.4241379310344827, + "f1_score": 0.4434943555473952, + "support": 1509, + "total_samples": 5192 + }, + "identity_hate": { + "threshold": 0.6586206896551723, + "f1_score": 0.4569864410513042, + "support": 277, + "total_samples": 5192 + } + } + } + } +} \ No newline at end of file diff --git a/evaluation_results/eval_20250401_143401/plots/per_class_comparison.png b/evaluation_results/eval_20250401_143401/plots/per_class_comparison.png new file mode 100644 index 0000000000000000000000000000000000000000..d858f7b2de7d51ed46e81e5380e6c7e6b564d52a Binary files /dev/null and b/evaluation_results/eval_20250401_143401/plots/per_class_comparison.png differ diff --git a/evaluation_results/eval_20250401_143401/plots/roc_all_classes.png b/evaluation_results/eval_20250401_143401/plots/roc_all_classes.png new file mode 100644 index 0000000000000000000000000000000000000000..da0f127480852ab04904aa3d35ec256a1dce69a8 --- /dev/null +++ b/evaluation_results/eval_20250401_143401/plots/roc_all_classes.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc99cf8a318efe9bde206d2e875905d037044c43b6d67f4a44cce849f30d2d0b +size 324306 diff --git a/evaluation_results/eval_20250401_143401/plots/roc_by_language.png b/evaluation_results/eval_20250401_143401/plots/roc_by_language.png new file mode 100644 index 0000000000000000000000000000000000000000..c50989fe73830970ed85659369d6de35969f46f5 --- /dev/null +++ b/evaluation_results/eval_20250401_143401/plots/roc_by_language.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26176df08c42f1841e5cafec0f988b05fe54f53b820a028a3bf574f48bf52839 +size 286397 diff --git a/evaluation_results/eval_20250401_143401/plots/roc_identity_hate.png b/evaluation_results/eval_20250401_143401/plots/roc_identity_hate.png new file mode 100644 index 0000000000000000000000000000000000000000..4efca5039444e1aeb746073e41dc9b3704f983a7 --- /dev/null +++ b/evaluation_results/eval_20250401_143401/plots/roc_identity_hate.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0673fb7730bdc288819a8ece4a0c7232915a1702c35781911560505ff3796b02 +size 198630 diff --git a/evaluation_results/eval_20250401_143401/plots/roc_insult.png b/evaluation_results/eval_20250401_143401/plots/roc_insult.png new file mode 100644 index 
0000000000000000000000000000000000000000..fd498cbbc7c4fb3992692bb0c6f6ad566369a026 --- /dev/null +++ b/evaluation_results/eval_20250401_143401/plots/roc_insult.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:46291fb38eaf918534dcd4541186258d0c6acf66990023cafa396d4f0f72760a +size 182740 diff --git a/evaluation_results/eval_20250401_143401/plots/roc_obscene.png b/evaluation_results/eval_20250401_143401/plots/roc_obscene.png new file mode 100644 index 0000000000000000000000000000000000000000..039b46278d2eeca5ccf93fb494a3090334aadef4 --- /dev/null +++ b/evaluation_results/eval_20250401_143401/plots/roc_obscene.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b99a652961562a3f3208601fae23e73b79935159b8fba6c2857eedfae54637bd +size 179325 diff --git a/evaluation_results/eval_20250401_143401/plots/roc_severe_toxic.png b/evaluation_results/eval_20250401_143401/plots/roc_severe_toxic.png new file mode 100644 index 0000000000000000000000000000000000000000..247cd6bceb40ee9827d6a59c9f239630120efd1b --- /dev/null +++ b/evaluation_results/eval_20250401_143401/plots/roc_severe_toxic.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:deb1e67ce887d64f22a7fbf225e0393e4b93d5f242c1803286d8be1f8ee3c5db +size 196608 diff --git a/evaluation_results/eval_20250401_143401/plots/roc_threat.png b/evaluation_results/eval_20250401_143401/plots/roc_threat.png new file mode 100644 index 0000000000000000000000000000000000000000..3f1c2e63a22e6b38c8f520ad2f2c137081faac8d --- /dev/null +++ b/evaluation_results/eval_20250401_143401/plots/roc_threat.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1bf854e4f769dd02ddc7ee955d6d8ac77f12879d6ca09cb553684a275226d167 +size 195438 diff --git a/evaluation_results/eval_20250401_143401/plots/roc_toxic.png b/evaluation_results/eval_20250401_143401/plots/roc_toxic.png new file mode 100644 index 0000000000000000000000000000000000000000..a3078eeb4a44eceedf978ba4fecc751372cd6f1b --- /dev/null +++ b/evaluation_results/eval_20250401_143401/plots/roc_toxic.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0aced50965a616578938c4d047bafb95c1c89b97d58073900d24a519c0616e5c +size 169233 diff --git a/evaluation_results/eval_20250401_143401/plots/threshold_comparison.png b/evaluation_results/eval_20250401_143401/plots/threshold_comparison.png new file mode 100644 index 0000000000000000000000000000000000000000..427193a38eda9934803b5ac1894f73861f6d7bad Binary files /dev/null and b/evaluation_results/eval_20250401_143401/plots/threshold_comparison.png differ diff --git a/evaluation_results/eval_20250401_143401/predictions.npz b/evaluation_results/eval_20250401_143401/predictions.npz new file mode 100644 index 0000000000000000000000000000000000000000..af19061234d038b5154e4f2f019ac3e18c9b24ee --- /dev/null +++ b/evaluation_results/eval_20250401_143401/predictions.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e60d667324ec828a3b0d76b860f1ac1df7df1ef76e8084dfaf27a7b145d0652 +size 783527 diff --git a/images/class_distribution.png b/images/class_distribution.png new file mode 100644 index 0000000000000000000000000000000000000000..08632ecc19d9186f815f6e9ae35425affe678567 --- /dev/null +++ b/images/class_distribution.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5a61a82e8a07799e47d3d51b99959dcd12d67d921653f4743e8ee9f695c234a1 +size 258031 diff --git a/images/language_distribution.png b/images/language_distribution.png new 
file mode 100644 index 0000000000000000000000000000000000000000..9d3410e1cf81af8547cf15b6fdf0ac26ba011db3 --- /dev/null +++ b/images/language_distribution.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a432ebd6167e035191312be66f76751fdac83c2e852d3ad274686d2a60eef646 +size 160711 diff --git a/images/toxicity_by_language.png b/images/toxicity_by_language.png new file mode 100644 index 0000000000000000000000000000000000000000..bf02301dea101b3d3b63732ef1d34a4e155b6c70 --- /dev/null +++ b/images/toxicity_by_language.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5fd33621693b53c270af1aa104da285c17a99afc0dfa7a74308b94d882642dc4 +size 213112 diff --git a/images/toxicity_correlation.png b/images/toxicity_correlation.png new file mode 100644 index 0000000000000000000000000000000000000000..3b270ed5a6e18e0f901b96db06e38dc89f503576 --- /dev/null +++ b/images/toxicity_correlation.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cdf9ef87edccfe54ec7403459ef909402e9b4863d736803da4eb2e5e7c329ef7 +size 268739 diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9cba6166669c84f25c730710d60301568ccc1456 --- /dev/null +++ b/model/__init__.py @@ -0,0 +1,3 @@ +""" +Model package for toxic comment classification. +""" \ No newline at end of file diff --git a/model/data/sampler.py b/model/data/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..8acbd8f30970c899a2df9bbdfe50da3d8d0c3952 --- /dev/null +++ b/model/data/sampler.py @@ -0,0 +1,56 @@ +from torch.utils.data import Sampler +import numpy as np +import logging +from collections import defaultdict +from pathlib import Path +import torch + +logger = logging.getLogger(__name__) + +class MultilabelStratifiedSampler(Sampler): + def __init__(self, labels, groups, batch_size, cached_size=None): + super().__init__(None) + self.labels = np.array(labels) + self.groups = np.array(groups) + self.batch_size = batch_size + self.num_samples = len(labels) + + # Simple validation + if len(self.labels) != len(self.groups): + raise ValueError("Length mismatch between labels and groups") + + # Create indices per group + self.group_indices = {} + unique_groups = np.unique(self.groups) + + for group in unique_groups: + indices = np.where(self.groups == group)[0] + if len(indices) > 0: + self.group_indices[group] = indices + + # Calculate group probabilities + group_sizes = np.array([len(indices) for indices in self.group_indices.values()]) + self.group_probs = group_sizes / group_sizes.sum() + self.valid_groups = list(self.group_indices.keys()) + + # Calculate number of batches + self.num_batches = self.num_samples // self.batch_size + if self.num_batches == 0: + self.num_batches = 1 + self.total_samples = self.num_batches * self.batch_size + + def __iter__(self): + indices = [] + for _ in range(self.num_batches): + batch = [] + for _ in range(self.batch_size): + # Select group and sample from it + group = np.random.choice(self.valid_groups, p=self.group_probs) + idx = np.random.choice(self.group_indices[group]) + batch.append(idx) + indices.extend(batch) + + return iter(indices) + + def __len__(self): + return self.total_samples \ No newline at end of file diff --git a/model/evaluation/evaluate.py b/model/evaluation/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..2442c6b1cfd320941a89cff8a4caf92994f5d49a --- /dev/null +++ b/model/evaluation/evaluate.py @@ -0,0 +1,745 @@ 
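The MultilabelStratifiedSampler added in model/data/sampler.py above draws each batch by first choosing a language group (with probability proportional to group size) and then sampling an index from that group with replacement, so class/language balance is approximate and individual rows can repeat within an epoch. A minimal usage sketch follows; it is an illustration only, and the CSV path, the "lang" column, and the label column names are assumptions carried over from other files in this diff.

import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset
from model.data.sampler import MultilabelStratifiedSampler

df = pd.read_csv("dataset/split/train.csv")
label_cols = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

labels = df[label_cols].values   # [N, 6] multilabel targets
groups = df["lang"].values       # language code per row, used as the stratification group

sampler = MultilabelStratifiedSampler(labels=labels, groups=groups, batch_size=32)

# Stand-in dataset of row indices; in real training this would be the tokenized dataset.
dataset = TensorDataset(torch.arange(len(df)))

# __iter__ yields flat indices (num_batches * batch_size of them), so the sampler is
# passed via `sampler=` rather than `batch_sampler=`, and shuffle is left off.
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for (batch_indices,) in loader:
    rows = df.iloc[batch_indices.numpy()]  # roughly language-balanced batch
    break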
+import torch +from model.language_aware_transformer import LanguageAwareTransformer +from transformers import XLMRobertaTokenizer +import pandas as pd +import numpy as np +from sklearn.metrics import ( + roc_auc_score, precision_recall_fscore_support, + confusion_matrix, hamming_loss, + accuracy_score, precision_score, recall_score, f1_score, + brier_score_loss +) +from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.model_selection import GridSearchCV +import matplotlib.pyplot as plt +from tqdm import tqdm +import json +import os +from datetime import datetime +import argparse +from torch.utils.data import Dataset, DataLoader +import gc +import multiprocessing +from pathlib import Path +import hashlib +import logging +from sklearn.metrics import make_scorer + +# Set matplotlib to non-interactive backend +plt.switch_backend('agg') + +# Set memory optimization environment variables +os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' +os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,expandable_segments:True' + +logger = logging.getLogger(__name__) + +class ToxicDataset(Dataset): + def __init__(self, df, tokenizer, config): + self.df = df + self.tokenizer = tokenizer + self.config = config + + # Ensure label columns are defined + if not hasattr(config, 'label_columns'): + self.label_columns = [ + 'toxic', 'severe_toxic', 'obscene', + 'threat', 'insult', 'identity_hate' + ] + logger.warning("Label columns not provided in config, using defaults") + else: + self.label_columns = config.label_columns + + # Verify all label columns exist in DataFrame + missing_columns = [col for col in self.label_columns if col not in df.columns] + if missing_columns: + raise ValueError(f"Missing label columns in dataset: {missing_columns}") + + # Convert labels to numpy array for efficiency + self.labels = df[self.label_columns].values + + # Create language mapping + self.lang_to_id = { + 'en': 0, 'ru': 1, 'tr': 2, 'es': 3, + 'fr': 4, 'it': 5, 'pt': 6 + } + + # Convert language codes to numeric indices + self.langs = np.array([self.lang_to_id.get(lang, 0) for lang in df['lang']]) + + print(f"Initialized dataset with {len(self)} samples") + logger.info(f"Dataset initialized with {len(self)} samples") + logger.info(f"Label columns: {self.label_columns}") + logger.info(f"Unique languages: {np.unique(df['lang'])}") + logger.info(f"Language mapping: {self.lang_to_id}") + + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + if idx % 1000 == 0: + print(f"Loading sample {idx}") + logger.debug(f"Loading sample {idx}") + + # Get text and labels + text = self.df.iloc[idx]['comment_text'] + labels = torch.FloatTensor(self.labels[idx]) + lang = torch.tensor(self.langs[idx], dtype=torch.long) # Ensure long dtype + + # Tokenize text + encoding = self.tokenizer( + text, + add_special_tokens=True, + max_length=self.config.max_length, + padding='max_length', + truncation=True, + return_attention_mask=True, + return_tensors='pt' + ) + + return { + 'input_ids': encoding['input_ids'].squeeze(0), + 'attention_mask': encoding['attention_mask'].squeeze(0), + 'labels': labels, + 'lang': lang + } + +class ThresholdOptimizer(BaseEstimator, ClassifierMixin): + """Custom estimator for threshold optimization""" + def __init__(self, threshold=0.5): + self.threshold = threshold + self.probabilities_ = None + + def fit(self, X, y): + # Store probabilities for prediction + self.probabilities_ = X + return self + + def predict(self, X): + # Apply threshold to probabilities + return (X > 
self.threshold).astype(int) + + def score(self, X, y): + # Return F1 score with proper handling of edge cases + predictions = self.predict(X) + + # Handle edge case where all samples are negative + if y.sum() == 0: + return 1.0 if predictions.sum() == 0 else 0.0 + + # Calculate metrics with zero_division=1 + try: + precision = precision_score(y, predictions, zero_division=1) + recall = recall_score(y, predictions, zero_division=1) + + # Calculate F1 manually to avoid warnings + if precision + recall == 0: + return 0.0 + f1 = 2 * (precision * recall) / (precision + recall) + return f1 + except Exception: + return 0.0 + +def load_model(model_path): + """Load model and tokenizer from versioned checkpoint directory""" + try: + # Check if model_path points to a specific checkpoint or base directory + model_dir = Path(model_path) + if model_dir.is_dir(): + # Check for 'latest' symlink first + latest_link = model_dir / 'latest' + if latest_link.exists() and latest_link.is_symlink(): + model_dir = latest_link.resolve() + logger.info(f"Using latest checkpoint: {model_dir}") + else: + # Find most recent checkpoint + checkpoints = sorted([ + d for d in model_dir.iterdir() + if d.is_dir() and d.name.startswith('checkpoint_epoch') + ]) + if checkpoints: + model_dir = checkpoints[-1] + logger.info(f"Using most recent checkpoint: {model_dir}") + else: + logger.info("No checkpoints found, using base directory") + + logger.info(f"Loading model from: {model_dir}") + + # Initialize the custom model architecture + model = LanguageAwareTransformer( + num_labels=6, + hidden_size=1024, + num_attention_heads=16, + model_name='xlm-roberta-large' + ) + + # Load the trained weights + weights_path = model_dir / 'pytorch_model.bin' + if not weights_path.exists(): + raise FileNotFoundError(f"Model weights not found at {weights_path}") + + state_dict = torch.load(weights_path) + model.load_state_dict(state_dict) + logger.info("Model weights loaded successfully") + + # Load base XLM-RoBERTa tokenizer directly + logger.info("Loading XLM-RoBERTa tokenizer...") + tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') + + # Load training metadata if available + metadata_path = model_dir / 'metadata.json' + if metadata_path.exists(): + with open(metadata_path) as f: + metadata = json.load(f) + logger.info(f"Loaded checkpoint metadata: Epoch {metadata.get('epoch', 'unknown')}") + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = model.to(device) + model.eval() + + return model, tokenizer, device + + except Exception as e: + logger.error(f"Error loading model: {str(e)}") + return None, None, None + +def optimize_threshold(y_true, y_pred_proba, n_steps=50): + """ + Optimize threshold using grid search to maximize F1 score + """ + # Handle edge case where all samples are negative + if y_true.sum() == 0: + return { + 'threshold': 0.5, # Use default threshold + 'f1_score': 1.0, # Perfect score for all negative samples + 'support': 0, + 'total_samples': len(y_true) + } + + # Create parameter grid + param_grid = { + 'threshold': np.linspace(0.3, 0.7, n_steps) + } + + # Initialize optimizer + optimizer = ThresholdOptimizer() + + # Run grid search with custom scoring + grid_search = GridSearchCV( + optimizer, + param_grid, + scoring=make_scorer(f1_score, zero_division=1), + cv=5, + n_jobs=-1, + verbose=0 + ) + + # Reshape probabilities to 2D array + X = y_pred_proba.reshape(-1, 1) + + # Fit grid search + grid_search.fit(X, y_true) + + # Get best results + best_threshold = 
grid_search.best_params_['threshold'] + best_f1 = grid_search.best_score_ + + return { + 'threshold': float(best_threshold), + 'f1_score': float(best_f1), + 'support': int(y_true.sum()), + 'total_samples': len(y_true) + } + +def calculate_optimal_thresholds(predictions, labels, langs): + """Calculate optimal thresholds for each class and language combination using Bayesian optimization""" + logger.info("Calculating optimal thresholds using Bayesian optimization...") + + toxicity_types = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + unique_langs = np.unique(langs) + + thresholds = { + 'global': {}, + 'per_language': {} + } + + # Calculate global thresholds + logger.info("Computing global thresholds...") + for i, class_name in enumerate(tqdm(toxicity_types, desc="Global thresholds")): + thresholds['global'][class_name] = optimize_threshold( + labels[:, i], + predictions[:, i], + n_steps=50 + ) + + # Calculate language-specific thresholds + logger.info("Computing language-specific thresholds...") + for lang in tqdm(unique_langs, desc="Language thresholds"): + lang_mask = langs == lang + if not lang_mask.any(): + continue + + thresholds['per_language'][str(lang)] = {} + lang_preds = predictions[lang_mask] + lang_labels = labels[lang_mask] + + for i, class_name in enumerate(toxicity_types): + # Only optimize if we have enough samples + if lang_labels[:, i].sum() >= 100: # Minimum samples threshold + thresholds['per_language'][str(lang)][class_name] = optimize_threshold( + lang_labels[:, i], + lang_preds[:, i], + n_steps=30 # Fewer iterations for per-language optimization + ) + else: + # Use global threshold if not enough samples + thresholds['per_language'][str(lang)][class_name] = thresholds['global'][class_name] + + return thresholds + +def evaluate_model(model, val_loader, device, output_dir): + """Evaluate model performance on validation set""" + model.eval() + all_predictions = [] + all_labels = [] + all_langs = [] + + total_samples = len(val_loader.dataset) + total_batches = len(val_loader) + + logger.info(f"\nStarting evaluation on {total_samples:,} samples in {total_batches} batches") + progress_bar = tqdm( + val_loader, + desc="Evaluating", + total=total_batches, + unit="batch", + ncols=100, + bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]' + ) + + with torch.inference_mode(): + for batch in progress_bar: + input_ids = batch['input_ids'].to(device) + attention_mask = batch['attention_mask'].to(device) + labels = batch['labels'].cpu().numpy() + langs = batch['lang'].cpu().numpy() + + outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + lang_ids=batch['lang'].to(device) + ) + + predictions = outputs['probabilities'].cpu().numpy() + + all_predictions.append(predictions) + all_labels.append(labels) + all_langs.append(langs) + + # Update progress bar description with batch size + progress_bar.set_description(f"Processed batch ({len(input_ids)} samples)") + + # Concatenate all batches with progress bar + logger.info("\nProcessing results...") + predictions = np.vstack(all_predictions) + labels = np.vstack(all_labels) + langs = np.concatenate(all_langs) + + logger.info(f"Computing metrics for {len(predictions):,} samples...") + + # Calculate metrics with progress indication + results = calculate_metrics(predictions, labels, langs) + + # Save results with progress indication + logger.info("Saving evaluation results...") + save_results( + results=results, + predictions=predictions, + labels=labels, + 
langs=langs, + output_dir=output_dir + ) + + # Plot metrics + logger.info("Generating metric plots...") + plot_metrics(results, output_dir, predictions=predictions, labels=labels) + + logger.info("Evaluation complete!") + return results, predictions + +def calculate_metrics(predictions, labels, langs): + """Calculate detailed metrics""" + results = { + 'default_thresholds': { + 'overall': {}, + 'per_language': {}, + 'per_class': {} + }, + 'optimized_thresholds': { + 'overall': {}, + 'per_language': {}, + 'per_class': {} + } + } + + # Default threshold of 0.5 + DEFAULT_THRESHOLD = 0.5 + + # Calculate metrics with default threshold + logger.info("Computing metrics with default threshold (0.5)...") + binary_predictions_default = (predictions > DEFAULT_THRESHOLD).astype(int) + results['default_thresholds']['overall'] = calculate_overall_metrics(predictions, labels, binary_predictions_default) + + # Calculate per-language metrics with default threshold + unique_langs = np.unique(langs) + logger.info(f"Computing per-language metrics with default threshold...") + for lang in tqdm(unique_langs, desc="Language metrics (default)", ncols=100): + lang_mask = langs == lang + if not lang_mask.any(): + continue + + lang_preds = predictions[lang_mask] + lang_labels = labels[lang_mask] + lang_binary_preds = binary_predictions_default[lang_mask] + + results['default_thresholds']['per_language'][str(lang)] = calculate_overall_metrics( + lang_preds, lang_labels, lang_binary_preds + ) + results['default_thresholds']['per_language'][str(lang)]['sample_count'] = int(lang_mask.sum()) + + # Calculate per-class metrics with default threshold + toxicity_types = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + logger.info("Computing per-class metrics with default threshold...") + for i, class_name in enumerate(tqdm(toxicity_types, desc="Class metrics (default)", ncols=100)): + results['default_thresholds']['per_class'][class_name] = calculate_class_metrics( + labels[:, i], + predictions[:, i], + binary_predictions_default[:, i], + DEFAULT_THRESHOLD + ) + + # Calculate optimal thresholds and corresponding metrics + logger.info("Computing optimal thresholds...") + thresholds = calculate_optimal_thresholds(predictions, labels, langs) + + # Apply optimal thresholds + logger.info("Computing metrics with optimized thresholds...") + binary_predictions_opt = np.zeros_like(predictions, dtype=int) + for i, class_name in enumerate(toxicity_types): + opt_threshold = thresholds['global'][class_name]['threshold'] + binary_predictions_opt[:, i] = (predictions[:, i] > opt_threshold).astype(int) + + # Calculate overall metrics with optimized thresholds + results['optimized_thresholds']['overall'] = calculate_overall_metrics(predictions, labels, binary_predictions_opt) + + # Calculate per-language metrics with optimized thresholds + logger.info(f"Computing per-language metrics with optimized thresholds...") + for lang in tqdm(unique_langs, desc="Language metrics (optimized)", ncols=100): + lang_mask = langs == lang + if not lang_mask.any(): + continue + + lang_preds = predictions[lang_mask] + lang_labels = labels[lang_mask] + lang_binary_preds = binary_predictions_opt[lang_mask] + + results['optimized_thresholds']['per_language'][str(lang)] = calculate_overall_metrics( + lang_preds, lang_labels, lang_binary_preds + ) + results['optimized_thresholds']['per_language'][str(lang)]['sample_count'] = int(lang_mask.sum()) + + # Calculate per-class metrics with optimized thresholds + logger.info("Computing per-class 
metrics with optimized thresholds...") + for i, class_name in enumerate(tqdm(toxicity_types, desc="Class metrics (optimized)", ncols=100)): + opt_threshold = thresholds['global'][class_name]['threshold'] + results['optimized_thresholds']['per_class'][class_name] = calculate_class_metrics( + labels[:, i], + predictions[:, i], + binary_predictions_opt[:, i], + opt_threshold + ) + + # Store the thresholds used + results['thresholds'] = thresholds + + return results + +def calculate_overall_metrics(predictions, labels, binary_predictions): + """Calculate overall metrics for multi-label classification""" + metrics = {} + + # AUC scores (threshold independent) + try: + metrics['auc_macro'] = roc_auc_score(labels, predictions, average='macro') + metrics['auc_weighted'] = roc_auc_score(labels, predictions, average='weighted') + except ValueError: + # Handle case where a class has no positive samples + metrics['auc_macro'] = 0.0 + metrics['auc_weighted'] = 0.0 + + # Precision, recall, F1 (threshold dependent) + precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support( + labels, binary_predictions, average='macro', zero_division=1 + ) + precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support( + labels, binary_predictions, average='weighted', zero_division=1 + ) + + metrics.update({ + 'precision_macro': precision_macro, + 'precision_weighted': precision_weighted, + 'recall_macro': recall_macro, + 'recall_weighted': recall_weighted, + 'f1_macro': f1_macro, + 'f1_weighted': f1_weighted + }) + + # Hamming loss + metrics['hamming_loss'] = hamming_loss(labels, binary_predictions) + + # Exact match + metrics['exact_match'] = accuracy_score(labels, binary_predictions) + + return metrics + +def calculate_class_metrics(labels, predictions, binary_predictions, threshold): + """Calculate metrics for a single class""" + # Handle case where there are no positive samples + if labels.sum() == 0: + return { + 'auc': 0.0, + 'threshold': threshold, + 'precision': 1.0 if binary_predictions.sum() == 0 else 0.0, + 'recall': 1.0, # All true negatives were correctly identified + 'f1': 1.0 if binary_predictions.sum() == 0 else 0.0, + 'support': 0, + 'brier': brier_score_loss(labels, predictions), + 'true_positives': 0, + 'false_positives': int(binary_predictions.sum()), + 'true_negatives': int((1 - binary_predictions).sum()), + 'false_negatives': 0 + } + + try: + auc = roc_auc_score(labels, predictions) + except ValueError: + auc = 0.0 + + # Calculate metrics with zero_division=1 + precision = precision_score(labels, binary_predictions, zero_division=1) + recall = recall_score(labels, binary_predictions, zero_division=1) + f1 = f1_score(labels, binary_predictions, zero_division=1) + + metrics = { + 'auc': auc, + 'threshold': threshold, + 'precision': precision, + 'recall': recall, + 'f1': f1, + 'support': int(labels.sum()), + 'brier': brier_score_loss(labels, predictions) + } + + # Confusion matrix metrics + tn, fp, fn, tp = confusion_matrix(labels, binary_predictions).ravel() + metrics.update({ + 'true_positives': int(tp), + 'false_positives': int(fp), + 'true_negatives': int(tn), + 'false_negatives': int(fn) + }) + + return metrics + +def save_results(results, predictions, labels, langs, output_dir): + """Save evaluation results and plots""" + os.makedirs(output_dir, exist_ok=True) + + # Save detailed metrics + with open(os.path.join(output_dir, 'evaluation_results.json'), 'w') as f: + json.dump(results, f, indent=2) + + # Save predictions for further analysis + 
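For reference, the threshold search defined earlier in this file (optimize_threshold plus calculate_optimal_thresholds; the log message says "Bayesian optimization", but the code performs a grid search via GridSearchCV) boils down to sweeping a fixed range of candidate thresholds per class and keeping the one with the best F1. Below is a standalone sketch of that sweep, written without the cross-validation wrapper, so its numbers will differ slightly from the GridSearchCV path.

import numpy as np
from sklearn.metrics import f1_score

def sweep_threshold(y_true, y_prob, low=0.3, high=0.7, n_steps=50):
    """Return the threshold in [low, high] that maximizes F1 for a single class."""
    best_t, best_f1 = 0.5, -1.0
    for t in np.linspace(low, high, n_steps):
        f1 = f1_score(y_true, (y_prob > t).astype(int), zero_division=1)
        if f1 > best_f1:
            best_t, best_f1 = float(t), float(f1)
    return {"threshold": best_t, "f1_score": best_f1,
            "support": int(y_true.sum()), "total_samples": int(len(y_true))}

# Per-class usage on the arrays produced by evaluate_model, e.g. for the first label:
# best = sweep_threshold(labels[:, 0], predictions[:, 0])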
np.savez_compressed( + os.path.join(output_dir, 'predictions.npz'), + predictions=predictions, + labels=labels, + langs=langs + ) + + # Log summary of results + logger.info("\nResults Summary:") + logger.info("\nDefault Threshold (0.5):") + logger.info(f"Macro F1: {results['default_thresholds']['overall']['f1_macro']:.3f}") + logger.info(f"Weighted F1: {results['default_thresholds']['overall']['f1_weighted']:.3f}") + + logger.info("\nOptimized Thresholds:") + logger.info(f"Macro F1: {results['optimized_thresholds']['overall']['f1_macro']:.3f}") + logger.info(f"Weighted F1: {results['optimized_thresholds']['overall']['f1_weighted']:.3f}") + + # Log threshold comparison + if 'thresholds' in results: + logger.info("\nOptimal Thresholds:") + for class_name, data in results['thresholds']['global'].items(): + logger.info(f"{class_name:>12}: {data['threshold']:.3f} (F1: {data['f1_score']:.3f})") + +def plot_metrics(results, output_dir, predictions=None, labels=None): + """Generate visualization plots comparing default vs optimized thresholds""" + plots_dir = os.path.join(output_dir, 'plots') + os.makedirs(plots_dir, exist_ok=True) + + # Plot comparison of metrics between default and optimized thresholds + if results.get('default_thresholds') and results.get('optimized_thresholds'): + plt.figure(figsize=(15, 8)) + + # Get metrics to compare + metrics = ['precision_macro', 'recall_macro', 'f1_macro'] + default_values = [results['default_thresholds']['overall'][m] for m in metrics] + optimized_values = [results['optimized_thresholds']['overall'][m] for m in metrics] + + x = np.arange(len(metrics)) + width = 0.35 + + plt.bar(x - width/2, default_values, width, label='Default Threshold (0.5)') + plt.bar(x + width/2, optimized_values, width, label='Optimized Thresholds') + + plt.ylabel('Score') + plt.title('Comparison of Default vs Optimized Thresholds') + plt.xticks(x, [m.replace('_', ' ').title() for m in metrics]) + plt.legend() + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.savefig(os.path.join(plots_dir, 'threshold_comparison.png')) + plt.close() + + # Plot per-class comparison + plt.figure(figsize=(15, 8)) + toxicity_types = list(results['default_thresholds']['per_class'].keys()) + + default_f1 = [results['default_thresholds']['per_class'][c]['f1'] for c in toxicity_types] + optimized_f1 = [results['optimized_thresholds']['per_class'][c]['f1'] for c in toxicity_types] + + x = np.arange(len(toxicity_types)) + width = 0.35 + + plt.bar(x - width/2, default_f1, width, label='Default Threshold (0.5)') + plt.bar(x + width/2, optimized_f1, width, label='Optimized Thresholds') + + plt.ylabel('F1 Score') + plt.title('Per-Class F1 Score Comparison') + plt.xticks(x, toxicity_types, rotation=45) + plt.legend() + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.savefig(os.path.join(plots_dir, 'per_class_comparison.png')) + plt.close() + +def main(): + parser = argparse.ArgumentParser(description='Evaluate toxic comment classifier') + parser.add_argument('--model_path', type=str, + default='weights/toxic_classifier_xlm-roberta-large', + help='Path to model directory containing checkpoints') + parser.add_argument('--checkpoint', type=str, + help='Specific checkpoint to evaluate (e.g., checkpoint_epoch05_20240213). 
If not specified, uses latest.') + parser.add_argument('--test_file', type=str, default='dataset/split/val.csv', + help='Path to test dataset') + parser.add_argument('--batch_size', type=int, default=64, + help='Batch size for evaluation') + parser.add_argument('--output_dir', type=str, default='evaluation_results', + help='Base directory to save results') + parser.add_argument('--num_workers', type=int, default=16, + help='Number of workers for data loading') + parser.add_argument('--cache_dir', type=str, default='cached_data', + help='Directory to store cached tokenized data') + parser.add_argument('--force_retokenize', action='store_true', + help='Force retokenization even if cache exists') + parser.add_argument('--prefetch_factor', type=int, default=2, + help='Number of batches to prefetch per worker') + parser.add_argument('--max_length', type=int, default=128, + help='Maximum sequence length for tokenization') + parser.add_argument('--gc_frequency', type=int, default=500, + help='Frequency of garbage collection') + parser.add_argument('--label_columns', nargs='+', + default=['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'], + help='List of label column names') + + args = parser.parse_args() + + # Create timestamped directory for this evaluation run + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + eval_dir = os.path.join(args.output_dir, f"eval_{timestamp}") + os.makedirs(eval_dir, exist_ok=True) + + # Save evaluation parameters + eval_params = { + 'timestamp': timestamp, + 'model_path': args.model_path, + 'checkpoint': args.checkpoint, + 'test_file': args.test_file, + 'batch_size': args.batch_size, + 'num_workers': args.num_workers, + 'cache_dir': args.cache_dir, + 'force_retokenize': args.force_retokenize, + 'prefetch_factor': args.prefetch_factor, + 'max_length': args.max_length, + 'gc_frequency': args.gc_frequency, + 'label_columns': args.label_columns + } + with open(os.path.join(eval_dir, 'eval_params.json'), 'w') as f: + json.dump(eval_params, f, indent=2) + + try: + # Load model + print("Loading multi-language toxic comment classifier model...") + model, tokenizer, device = load_model(args.model_path) + + if model is None: + return + + # Load test data + print("\nLoading test dataset...") + test_df = pd.read_csv(args.test_file) + print(f"Loaded {len(test_df):,} test samples") + + # Verify label columns exist in the DataFrame + missing_columns = [col for col in args.label_columns if col not in test_df.columns] + if missing_columns: + raise ValueError(f"Missing label columns in dataset: {missing_columns}") + + # Create test dataset + test_dataset = ToxicDataset( + test_df, + tokenizer, + args + ) + + # Configure DataLoader with optimized settings + test_loader = DataLoader( + test_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + prefetch_factor=args.prefetch_factor, + persistent_workers=True if args.num_workers > 0 else False, + drop_last=False + ) + + # Evaluate model + results = evaluate_model(model, test_loader, device, eval_dir) + + print(f"\nEvaluation complete! 
Results saved to {eval_dir}") + return results + + except Exception as e: + print(f"Error during evaluation: {str(e)}") + raise + + finally: + # Cleanup + plt.close('all') + gc.collect() + torch.cuda.empty_cache() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/model/hyperparameter_tuning.py b/model/hyperparameter_tuning.py new file mode 100644 index 0000000000000000000000000000000000000000..c9d258389b06cf635ef78e3c616d747cb2762ddc --- /dev/null +++ b/model/hyperparameter_tuning.py @@ -0,0 +1,261 @@ +import optuna +from optuna.samplers import TPESampler +from optuna.pruners import MedianPruner +import wandb +import pandas as pd +from model.train import train, init_model, create_dataloaders, ToxicDataset +from model.training_config import TrainingConfig +from transformers import XLMRobertaTokenizer +import json +import torch + +def load_dataset(file_path: str): + """Load and prepare dataset""" + df = pd.read_csv(file_path) + tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') + config = TrainingConfig() + return ToxicDataset(df, tokenizer, config) + +class HyperparameterTuner: + def __init__(self, train_dataset, val_dataset, n_trials=10): + self.train_dataset = train_dataset + self.val_dataset = val_dataset + self.n_trials = n_trials + + # Make pruning more aggressive + self.study = optuna.create_study( + direction="maximize", + sampler=TPESampler(seed=42), + pruner=MedianPruner( + n_startup_trials=2, + n_warmup_steps=2, + interval_steps=1 + ) + ) + + def objective(self, trial): + """Objective function for Optuna optimization with optimal ranges""" + # Define hyperparameter search space with optimal ranges + config_params = { + # Fixed architecture parameters + "model_name": "xlm-roberta-large", + "hidden_size": 1024, # Fixed to original + "num_attention_heads": 16, # Fixed to original + + # Optimized ranges based on trials + "lr": trial.suggest_float("lr", 1e-5, 5e-5, log=True), # Best range from trial-8/4 + "batch_size": trial.suggest_categorical("batch_size", [32, 64]), # Top performers + "model_dropout": trial.suggest_float("model_dropout", 0.3, 0.45), # Trial-8's 0.445 effective + "weight_decay": trial.suggest_float("weight_decay", 0.01, 0.03), # Best regularization + "grad_accum_steps": trial.suggest_int("grad_accum_steps", 1, 4), # Keep for throughput optimization + + # Fixed training parameters + "epochs": 2, + "mixed_precision": "bf16", + "max_length": 128, + "fp16": False, + "distributed": False, + "world_size": 1, + "num_workers": 12, + "activation_checkpointing": True, + "tensor_float_32": True, + "gc_frequency": 500 + } + + # Create config + config = TrainingConfig(**config_params) + + # Initialize wandb for this trial with better metadata + wandb.init( + project="toxic-classification-hparam-tuning", + name=f"trial-{trial.number}", + config={ + **config_params, + 'trial_number': trial.number, + 'pruner': str(trial.study.pruner), + 'sampler': str(trial.study.sampler) + }, + reinit=True, + tags=['hyperparameter-optimization', f'trial-{trial.number}'] + ) + + try: + # Create model and dataloaders + model = init_model(config) + train_loader, val_loader = create_dataloaders( + self.train_dataset, + self.val_dataset, + config + ) + + # Train and get metrics + metrics = train(model, train_loader, val_loader, config) + + # Log detailed metrics + wandb.log({ + 'final_val_auc': metrics['val/auc'], + 'final_val_loss': metrics['val/loss'], + 'final_train_loss': metrics['train/loss'], + 'peak_gpu_memory': torch.cuda.max_memory_allocated() / 
1e9 if torch.cuda.is_available() else 0, + 'trial_completed': True + }) + + # Report intermediate values for pruning + trial.report(metrics['val/auc'], step=config.epochs) + + # Handle pruning + if trial.should_prune(): + wandb.log({'pruned': True}) + raise optuna.TrialPruned() + + return metrics['val/auc'] + + except Exception as e: + wandb.log({ + 'error': str(e), + 'trial_failed': True + }) + print(f"Trial failed: {str(e)}") + raise optuna.TrialPruned() + + finally: + # Cleanup + if 'model' in locals(): + del model + torch.cuda.empty_cache() + wandb.finish() + + def run_optimization(self): + """Run the hyperparameter optimization""" + print("Starting hyperparameter optimization...") + print("Search space:") + print(" - Learning rate: 1e-5 to 5e-5") + print(" - Batch size: [32, 64]") + print(" - Dropout: 0.3 to 0.45") + print(" - Weight decay: 0.01 to 0.03") + print(" - Gradient accumulation steps: 1 to 4") + print("\nFixed parameters:") + print(" - Hidden size: 1024 (original)") + print(" - Attention heads: 16 (original)") + + try: + self.study.optimize( + self.objective, + n_trials=self.n_trials, + timeout=None, # No timeout + callbacks=[self._log_trial] + ) + + # Print optimization results + print("\nBest trial:") + best_trial = self.study.best_trial + print(f" Value: {best_trial.value:.4f}") + print(" Params:") + for key, value in best_trial.params.items(): + print(f" {key}: {value}") + + # Save study results with more details + self._save_study_results() + + except KeyboardInterrupt: + print("\nOptimization interrupted by user.") + self._save_study_results() # Save results even if interrupted + except Exception as e: + print(f"Optimization failed: {str(e)}") + raise + + def _log_trial(self, study, trial): + """Callback to log trial results with enhanced metrics""" + if trial.value is not None: + metrics = { + "best_auc": study.best_value, + "trial_auc": trial.value, + "trial_number": trial.number, + **trial.params + } + + # Add optimization progress metrics + if len(study.trials) > 1: + metrics.update({ + "optimization_progress": { + "trials_completed": len(study.trials), + "improvement_rate": (study.best_value - study.trials[0].value) / len(study.trials), + "best_trial_number": study.best_trial.number + } + }) + + wandb.log(metrics) + + def _save_study_results(self): + """Save optimization results with enhanced metadata""" + import joblib + from pathlib import Path + from datetime import datetime + + # Create directory if it doesn't exist + results_dir = Path("optimization_results") + results_dir.mkdir(exist_ok=True) + + # Save study object + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + study_path = results_dir / f"hparam_optimization_study_{timestamp}.pkl" + joblib.dump(self.study, study_path) + + # Save comprehensive results + results = { + "best_trial": { + "number": self.study.best_trial.number, + "value": self.study.best_value, + "params": self.study.best_trial.params + }, + "study_statistics": { + "n_trials": len(self.study.trials), + "n_completed": len([t for t in self.study.trials if t.state == optuna.trial.TrialState.COMPLETE]), + "n_pruned": len([t for t in self.study.trials if t.state == optuna.trial.TrialState.PRUNED]), + "datetime_start": self.study.trials[0].datetime_start.isoformat(), + "datetime_complete": datetime.now().isoformat() + }, + "search_space": { + "lr": {"low": 1e-5, "high": 5e-5}, + "batch_size": [32, 64], + "model_dropout": {"low": 0.3, "high": 0.45}, + "weight_decay": {"low": 0.01, "high": 0.03}, + "grad_accum_steps": {"low": 1, "high": 4} + 
}, + "trial_history": [ + { + "number": t.number, + "value": t.value, + "state": str(t.state), + "params": t.params if hasattr(t, 'params') else None + } + for t in self.study.trials + ] + } + + results_path = results_dir / f"optimization_results_{timestamp}.json" + with open(results_path, "w") as f: + json.dump(results, f, indent=4) + + print(f"\nResults saved to:") + print(f" - Study: {study_path}") + print(f" - Results: {results_path}") + +def main(): + """Main function to run hyperparameter optimization""" + # Load datasets + train_dataset = load_dataset("dataset/split/train.csv") + val_dataset = load_dataset("dataset/split/val.csv") + + # Initialize tuner + tuner = HyperparameterTuner( + train_dataset=train_dataset, + val_dataset=val_dataset, + n_trials=10 + ) + + # Run optimization + tuner.run_optimization() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/model/inference_optimized.py b/model/inference_optimized.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d988e19c3dcf22f2ac11f38d4396378bc15896 --- /dev/null +++ b/model/inference_optimized.py @@ -0,0 +1,165 @@ +import torch +import onnxruntime as ort +from transformers import XLMRobertaTokenizer +import numpy as np +import os + +class OptimizedToxicityClassifier: + """High-performance toxicity classifier for production""" + + def __init__(self, onnx_path=None, pytorch_path=None, device='cuda'): + self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') + + # Language mapping + self.lang_map = { + 'en': 0, 'ru': 1, 'tr': 2, 'es': 3, + 'fr': 4, 'it': 5, 'pt': 6 + } + + # Label names + self.label_names = [ + 'toxic', 'severe_toxic', 'obscene', + 'threat', 'insult', 'identity_hate' + ] + + # Load ONNX model if path provided + if onnx_path and os.path.exists(onnx_path): + # Use ONNX Runtime for inference + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] \ + if device == 'cuda' and 'CUDAExecutionProvider' in ort.get_available_providers() \ + else ['CPUExecutionProvider'] + + self.session = ort.InferenceSession(onnx_path, providers=providers) + self.use_onnx = True + print(f"Loaded ONNX model from {onnx_path}") + # Fall back to PyTorch if ONNX not available + elif pytorch_path: + from model.language_aware_transformer import LanguageAwareTransformer + + # Handle directory structure with checkpoint folders and 'latest' symlink + if os.path.isdir(pytorch_path): + # Check if there's a 'latest' symlink + latest_path = os.path.join(pytorch_path, 'latest') + if os.path.islink(latest_path) and os.path.exists(latest_path): + checkpoint_dir = latest_path + else: + # If no 'latest' symlink, look for checkpoint dirs and use the most recent one + checkpoint_dirs = [d for d in os.listdir(pytorch_path) if d.startswith('checkpoint_epoch')] + if checkpoint_dirs: + checkpoint_dirs.sort() # Sort to get the latest by name + checkpoint_dir = os.path.join(pytorch_path, checkpoint_dirs[-1]) + else: + raise ValueError(f"No checkpoint directories found in {pytorch_path}") + + # Look for PyTorch model files in the checkpoint directory + model_file = None + potential_files = ['pytorch_model.bin', 'model.pt', 'model.pth'] + for file in potential_files: + candidate = os.path.join(checkpoint_dir, file) + if os.path.exists(candidate): + model_file = candidate + break + + if not model_file: + raise FileNotFoundError(f"No model file found in {checkpoint_dir}") + + print(f"Using model from checkpoint: {checkpoint_dir}") + model_path = model_file + else: + # If pytorch_path is a direct 
file path + model_path = pytorch_path + + self.model = LanguageAwareTransformer(num_labels=6) + self.model.load_state_dict(torch.load(model_path, map_location=device)) + self.model.to(device) + self.model.eval() + self.use_onnx = False + self.device = device + print(f"Loaded PyTorch model from {model_path}") + else: + raise ValueError("Either onnx_path or pytorch_path must be provided") + + def predict(self, texts, langs=None, batch_size=8): + """ + Predict toxicity for a list of texts + + Args: + texts: List of text strings + langs: List of language codes (e.g., 'en', 'fr') + batch_size: Batch size for processing + + Returns: + List of dictionaries with toxicity predictions + """ + results = [] + + # Auto-detect or default language if not provided + if langs is None: + langs = ['en'] * len(texts) + + # Convert language codes to IDs + lang_ids = [self.lang_map.get(lang, 0) for lang in langs] + + # Process in batches + for i in range(0, len(texts), batch_size): + batch_texts = texts[i:i+batch_size] + batch_langs = lang_ids[i:i+batch_size] + + # Tokenize + inputs = self.tokenizer( + batch_texts, + padding=True, + truncation=True, + max_length=128, + return_tensors='pt' + ) + + # Get predictions + if self.use_onnx: + # ONNX inference + ort_inputs = { + 'input_ids': inputs['input_ids'].numpy(), + 'attention_mask': inputs['attention_mask'].numpy(), + 'lang_ids': np.array(batch_langs, dtype=np.int64) + } + ort_outputs = self.session.run(None, ort_inputs) + probabilities = 1 / (1 + np.exp(-ort_outputs[0])) # sigmoid + else: + # PyTorch inference + with torch.no_grad(): + inputs = {k: v.to(self.device) for k, v in inputs.items()} + lang_tensor = torch.tensor(batch_langs, dtype=torch.long, device=self.device) + outputs = self.model( + input_ids=inputs['input_ids'], + attention_mask=inputs['attention_mask'], + lang_ids=lang_tensor, + mode='inference' + ) + probabilities = outputs['probabilities'].cpu().numpy() + + # Format results + for j, (text, lang, probs) in enumerate(zip(batch_texts, langs[i:i+batch_size], probabilities)): + # Apply optimal thresholds per language + lang_thresholds = { # Increased by ~20% from original values + 'default': [0.60, 0.54, 0.60, 0.48, 0.60, 0.50] # Increased by ~20% from original values + # mapping [toxic, severe_toxic, obscene, threat, insult, identity_hate] + } + + + thresholds = lang_thresholds.get(lang, lang_thresholds['default']) + is_toxic = (probs >= np.array(thresholds)).astype(bool) + + result = { + 'text': text, + 'language': lang, + 'probabilities': { + label: float(prob) for label, prob in zip(self.label_names, probs) + }, + 'is_toxic': bool(is_toxic.any()), + 'toxic_categories': [ + self.label_names[k] for k in range(len(is_toxic)) if is_toxic[k] + ] + } + results.append(result) + + return results \ No newline at end of file diff --git a/model/language_aware_transformer.py b/model/language_aware_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..9cedf5d81cc3ac503abd67633ae6a552edf77793 --- /dev/null +++ b/model/language_aware_transformer.py @@ -0,0 +1,369 @@ +# language_aware_transformer.py +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import XLMRobertaModel +from typing import Optional +import logging +import os +import json + +logger = logging.getLogger(__name__) + +SUPPORTED_LANGUAGES = { + 'en': 0, 'ru': 1, 'tr': 2, 'es': 3, + 'fr': 4, 'it': 5, 'pt': 6 +} + +def validate_lang_ids(lang_ids): + if not isinstance(lang_ids, torch.Tensor): + lang_ids = torch.tensor(lang_ids, 
dtype=torch.long) + # Use actual language count instead of hardcoded 9 + return torch.clamp(lang_ids, min=0, max=len(SUPPORTED_LANGUAGES)-1) + +class LanguageAwareClassifier(nn.Module): + def __init__(self, hidden_size=1024, num_labels=6): + super().__init__() + self.lang_embed = nn.Embedding(7, 64) # 7 languages + + # Simplified classifier layers + self.classifier = nn.Sequential( + nn.Linear(hidden_size + 64, 512), + nn.LayerNorm(512), + nn.GELU(), + nn.Linear(512, num_labels) + ) + + # Vectorized language-specific thresholds + self.lang_thresholds = nn.Parameter( + torch.ones(len(SUPPORTED_LANGUAGES), num_labels) + ) + # Initialize with small random values around 1 + nn.init.normal_(self.lang_thresholds, mean=1.0, std=0.01) + + self._init_weights() + + def _init_weights(self): + """Initialize weights with Xavier uniform""" + for module in self.classifier: + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.LayerNorm): + nn.init.constant_(module.bias, 0) + nn.init.constant_(module.weight, 1.0) + + def forward(self, x, lang_ids): + # Ensure lang_ids is a tensor of integers + if not isinstance(lang_ids, torch.Tensor): + lang_ids = torch.tensor(lang_ids, dtype=torch.long, device=x.device) + elif lang_ids.dtype != torch.long: + lang_ids = lang_ids.long() + + # Get language embeddings + lang_emb = self.lang_embed(lang_ids) # Shape: [batch_size, 64] + + # Concatenate features with language embeddings for classification + combined = torch.cat([x, lang_emb], dim=-1) # Shape: [batch_size, hidden_size + 64] + + # Apply simplified classifier + logits = self.classifier(combined) # Shape: [batch_size, num_labels] + + # Apply language-specific thresholds using vectorized operations + thresholds = self.lang_thresholds[lang_ids] # Shape: [batch_size, num_labels] + logits = logits * torch.sigmoid(thresholds) # Shape: [batch_size, num_labels] + + return logits + +class WeightedBCEWithLogitsLoss(nn.Module): + def __init__(self, gamma=2.0, reduction='mean'): + super().__init__() + self.gamma = gamma + self.reduction = reduction + + def forward(self, logits, targets, weights=None): + bce_loss = F.binary_cross_entropy_with_logits( + logits, targets, reduction='none' + ) + pt = torch.exp(-bce_loss) + focal_loss = (1 - pt)**self.gamma * bce_loss + if weights is not None: + focal_loss *= weights + return focal_loss.mean() + +class LanguageAwareTransformer(nn.Module): + def __init__( + self, + num_labels: int = 6, + hidden_size: int = 1024, + num_attention_heads: int = 16, + model_name: str = "xlm-roberta-large", + dropout: float = 0.0 + ): + super().__init__() + + # Validate supported languages + if not SUPPORTED_LANGUAGES: + raise ValueError("No supported languages defined") + logger.info(f"Initializing model with {len(SUPPORTED_LANGUAGES)} supported languages: {list(SUPPORTED_LANGUAGES.keys())}") + + # Load pretrained model + self.base_model = XLMRobertaModel.from_pretrained(model_name) + self.config = self.base_model.config + + # Project to custom hidden size if different from original + self.original_hidden_size = self.config.hidden_size + self.needs_projection = hidden_size != self.original_hidden_size + if self.needs_projection: + self.dim_projection = nn.Sequential( + nn.Linear(self.original_hidden_size, hidden_size), + nn.LayerNorm(hidden_size), + nn.GELU() + ) + + # Working hidden size + self.working_hidden_size = hidden_size if self.needs_projection else self.original_hidden_size + + 
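The WeightedBCEWithLogitsLoss defined just above is a focal-style loss: the per-element BCE is scaled by (1 - p_t)^gamma, which shrinks the contribution of examples the model already predicts confidently and keeps the gradient focused on hard cases. A small self-contained illustration with hypothetical logits and the default gamma=2.0; the printed values are approximate.

import torch
import torch.nn.functional as F

gamma = 2.0
logits  = torch.tensor([4.0, -4.0, 0.1])   # confident positive, confident negative, uncertain positive
targets = torch.tensor([1.0,  0.0, 1.0])

bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
pt = torch.exp(-bce)                 # probability assigned to the true label
focal = (1 - pt) ** gamma * bce

print(bce.tolist())    # ~[0.018, 0.018, 0.645]
print(focal.tolist())  # ~[5.8e-06, 5.8e-06, 0.146] -> easy examples contribute almost nothing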
# Language-aware attention components with dynamic language count + num_languages = len(SUPPORTED_LANGUAGES) + self.lang_embed = nn.Embedding(num_languages, 64) + + # Register supported languages for validation + self.register_buffer('valid_lang_ids', torch.arange(num_languages)) + + # Optimized language projection for attention bias + self.lang_proj = nn.Sequential( + nn.Linear(64, num_attention_heads * hidden_size // num_attention_heads), + nn.LayerNorm(num_attention_heads * hidden_size // num_attention_heads), + nn.Tanh() # Bounded activation for stable attention scores + ) + + # Multi-head attention with optimized head dimension + head_dim = hidden_size // num_attention_heads + self.scale = head_dim ** -0.5 # Scaling factor for attention scores + + self.q_proj = nn.Linear(hidden_size, hidden_size) + self.k_proj = nn.Linear(hidden_size, hidden_size) + self.v_proj = nn.Linear(hidden_size, hidden_size) + self.dropout = nn.Dropout(dropout) + + self.post_attention = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.LayerNorm(hidden_size), + nn.GELU() + ) + + # Output classifier + self.classifier = nn.Sequential( + nn.Linear(hidden_size, 512), + nn.LayerNorm(512), + nn.GELU(), + nn.Linear(512, num_labels) + ) + + self._init_weights() + self.gradient_checkpointing = False + + def _init_weights(self): + """Initialize weights with careful scaling""" + for module in [self.lang_proj, self.q_proj, self.k_proj, self.v_proj, + self.post_attention, self.classifier]: + if isinstance(module, nn.Sequential): + for layer in module: + if isinstance(layer, nn.Linear): + # Use scaled initialization for attention projections + if layer in [self.q_proj, self.k_proj, self.v_proj]: + nn.init.normal_(layer.weight, std=0.02) + else: + nn.init.xavier_uniform_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + elif isinstance(layer, nn.LayerNorm): + nn.init.ones_(layer.weight) + nn.init.zeros_(layer.bias) + elif isinstance(module, nn.Linear): + if module in [self.q_proj, self.k_proj, self.v_proj]: + nn.init.normal_(module.weight, std=0.02) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + + def gradient_checkpointing_enable(self): + self.gradient_checkpointing = True + self.base_model.gradient_checkpointing_enable() + + def gradient_checkpointing_disable(self): + self.gradient_checkpointing = False + self.base_model.gradient_checkpointing_disable() + + def validate_lang_ids(self, lang_ids: torch.Tensor) -> torch.Tensor: + """ + Validate and normalize language IDs + Args: + lang_ids: Tensor of language IDs + Returns: + Validated and normalized language ID tensor + Raises: + ValueError if too many invalid IDs detected + """ + if not isinstance(lang_ids, torch.Tensor): + lang_ids = torch.tensor(lang_ids, dtype=torch.long, device=self.valid_lang_ids.device) + elif lang_ids.dtype != torch.long: + lang_ids = lang_ids.long() + + # Check for out-of-bounds IDs + invalid_mask = ~torch.isin(lang_ids, self.valid_lang_ids) + num_invalid = invalid_mask.sum().item() + + if num_invalid > 0: + invalid_ratio = num_invalid / lang_ids.numel() + if invalid_ratio > 0.1: # More than 10% invalid + raise ValueError( + f"Too many invalid language IDs detected ({num_invalid} out of {lang_ids.numel()}). " + f"Valid range is 0-{len(SUPPORTED_LANGUAGES)-1}" + ) + # Log warning and clamp invalid IDs + logger.warning( + f"Found {num_invalid} invalid language IDs. " + f"Valid range is 0-{len(SUPPORTED_LANGUAGES)-1}. 
" + "Invalid IDs will be clamped to valid range." + ) + lang_ids = torch.clamp(lang_ids, min=0, max=len(SUPPORTED_LANGUAGES)-1) + + return lang_ids + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + labels: Optional[torch.Tensor] = None, + lang_ids: Optional[torch.Tensor] = None, + mode: str = 'train' + ) -> dict: + device = input_ids.device + batch_size = input_ids.size(0) + + # Handle language IDs with validation + if lang_ids is None: + lang_ids = torch.zeros(batch_size, dtype=torch.long, device=device) + + # Validate and normalize language IDs + try: + lang_ids = self.validate_lang_ids(lang_ids) + except ValueError as e: + logger.error(f"Language ID validation failed: {str(e)}") + logger.error("Falling back to default language (0)") + lang_ids = torch.zeros_like(lang_ids) + + # Base model forward pass + hidden_states = self.base_model( + input_ids=input_ids, + attention_mask=attention_mask + ).last_hidden_state # Shape: [batch_size, seq_len, hidden_size] + + # Check for numerical instabilities + if hidden_states.isnan().any(): + raise ValueError("NaN detected in hidden states") + if hidden_states.isinf().any(): + raise ValueError("Inf detected in hidden states") + + # Project if needed + if self.needs_projection: + hidden_states = self.dim_projection(hidden_states) + + # Generate language-aware attention bias + lang_emb = self.lang_embed(lang_ids) # [batch_size, 64] + lang_bias = self.lang_proj(lang_emb) # [batch_size, num_heads * head_dim] + + # Reshape for multi-head attention + batch_size, seq_len, hidden_size = hidden_states.shape + num_heads = self.config.num_attention_heads + head_dim = hidden_size // num_heads + + # Project queries, keys, and values + q = self.q_proj(hidden_states).view(batch_size, seq_len, num_heads, head_dim) + k = self.k_proj(hidden_states).view(batch_size, seq_len, num_heads, head_dim) + v = self.v_proj(hidden_states).view(batch_size, seq_len, num_heads, head_dim) + + # Transpose for attention computation + q = q.transpose(1, 2) # [batch_size, num_heads, seq_len, head_dim] + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Compute attention scores with language bias + attn_bias = lang_bias.view(batch_size, num_heads, head_dim, 1) + attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale + attn_scores = attn_scores + torch.matmul(q, attn_bias).squeeze(-1).unsqueeze(-1) + + # Apply attention mask + if attention_mask is not None: + attn_scores = attn_scores.masked_fill( + ~attention_mask.bool().unsqueeze(1).unsqueeze(2), + float('-inf') + ) + + # Compute attention weights and apply to values + attn_weights = F.softmax(attn_scores, dim=-1) + attn_weights = self.dropout(attn_weights) + attention_output = torch.matmul(attn_weights, v) + + # Reshape and post-process + attention_output = attention_output.transpose(1, 2).contiguous().view( + batch_size, seq_len, hidden_size + ) + output = self.post_attention(attention_output) + + # Get logits using the [CLS] token output + logits = self.classifier(output[:, 0]) + + # Apply language-specific threshold adjustments based on statistical patterns + LANG_THRESHOLD_ADJUSTMENTS = { + 0: [0.00, 0.00, 0.00, 0.00, 0.00, 0.00], # en (baseline) + 1: [-0.02, 0.00, 0.02, 0.00, -0.03, 0.00], # ru (higher insult tendency) + 2: [-0.02, 0.00, 0.02, 0.00, -0.03, 0.00], # tr + 3: [-0.02, 0.00, 0.02, 0.00, -0.03, 0.00], # es + 4: [-0.02, 0.00, 0.02, 0.00, -0.03, 0.00], # fr + 5: [-0.02, 0.00, 0.02, 0.00, -0.03, 0.00], # it + 6: [-0.02, 0.00, 0.02, 0.00, -0.03, 0.00], # pt + } + + # Get 
threshold adjustments for each instance in batch + if mode == 'inference': + threshold_adj = torch.tensor( + [LANG_THRESHOLD_ADJUSTMENTS[lang.item()] for lang in lang_ids], + device=logits.device + ) + # Apply adjustment to logits + logits = logits + threshold_adj + + probabilities = torch.sigmoid(logits) + + # Prepare output dictionary + result = { + 'logits': logits, + 'probabilities': probabilities + } + + # Add loss if labels are provided + if labels is not None: + loss_fct = WeightedBCEWithLogitsLoss() + result['loss'] = loss_fct(logits, labels) + + return result + + def save_pretrained(self, save_path: str): + os.makedirs(save_path, exist_ok=True) + torch.save(self.state_dict(), os.path.join(save_path, 'pytorch_model.bin')) + + config_dict = { + 'num_labels': self.classifier[-1].out_features, + 'hidden_size': self.config.hidden_size, + 'num_attention_heads': self.config.num_attention_heads, + 'model_name': self.config.name_or_path, + 'dropout': self.dropout.p + } + + with open(os.path.join(save_path, 'config.json'), 'w') as f: + json.dump(config_dict, f, indent=2) diff --git a/model/predict.py b/model/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..21a8afbb53c35db139874f6d0754f59a0905e351 --- /dev/null +++ b/model/predict.py @@ -0,0 +1,416 @@ +import torch +from model.language_aware_transformer import LanguageAwareTransformer +from transformers import XLMRobertaTokenizer +import os +import re +import json +from pathlib import Path +import logging +from langdetect import detect, DetectorFactory +from langdetect.lang_detect_exception import LangDetectException +import sys +import locale +import io + +# Force UTF-8 encoding for stdin/stdout +if sys.platform == 'win32': + # Windows-specific handling + import msvcrt + sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace') + # Set console to UTF-8 mode + os.system('chcp 65001') +else: + # Unix-like systems + if sys.stdout.encoding != 'utf-8': + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') + if sys.stdin.encoding != 'utf-8': + sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') + +# Set up logging with UTF-8 encoding +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout) + ] +) +logger = logging.getLogger(__name__) + +# Ensure reproducibility with langdetect +DetectorFactory.seed = 0 + +SUPPORTED_LANGUAGES = { + 'en': 0, 'ru': 1, 'tr': 2, 'es': 3, + 'fr': 4, 'it': 5, 'pt': 6 +} + +# Default thresholds optimized on validation set +DEFAULT_THRESHOLDS = { + 'toxic': 0.80, # Optimized for general toxicity + 'severe_toxic': 0.45, # Lower to catch serious cases + 'obscene': 0.48, # Balanced for precision/recall + 'threat': 0.42, # Lower to catch potential threats + 'insult': 0.70, # Balanced for common cases + 'identity_hate': 0.43 # Lower to catch hate speech +} + +# Unicode ranges for different scripts +UNICODE_RANGES = { + 'ru': [ + (0x0400, 0x04FF), # Cyrillic + (0x0500, 0x052F), # Cyrillic Supplement + ], + 'tr': [ + (0x011E, 0x011F), # Ğ ğ + (0x0130, 0x0131), # İ ı + (0x015E, 0x015F), # Ş ş + ], + 'es': [ + (0x00C1, 0x00C1), # Á + (0x00C9, 0x00C9), # É + (0x00CD, 0x00CD), # Í + (0x00D1, 0x00D1), # Ñ + (0x00D3, 0x00D3), # Ó + (0x00DA, 0x00DA), # Ú + (0x00DC, 0x00DC), # Ü + ], + 'fr': [ + (0x00C0, 0x00C6), # À-Æ + (0x00C8, 0x00CB), # È-Ë + (0x00CC, 0x00CF), # Ì-Ï + (0x00D2, 0x00D6), # Ò-Ö + 
(0x0152, 0x0153), # Œ œ + ], + 'it': [ + (0x00C0, 0x00C0), # À + (0x00C8, 0x00C8), # È + (0x00C9, 0x00C9), # É + (0x00CC, 0x00CC), # Ì + (0x00D2, 0x00D2), # Ò + (0x00D9, 0x00D9), # Ù + ], + 'pt': [ + (0x00C0, 0x00C3), # À-à + (0x00C7, 0x00C7), # Ç + (0x00C9, 0x00CA), # É-Ê + (0x00D3, 0x00D5), # Ó-Õ + ] +} + +def load_model(model_path): + """Load the trained model and tokenizer""" + try: + # Convert to absolute Path object + model_dir = Path(model_path).absolute() + + if model_dir.is_dir(): + # Check for 'latest' symlink first + latest_link = model_dir / 'latest' + if latest_link.exists() and latest_link.is_symlink(): + # Get the target of the symlink + target = latest_link.readlink() + # If target is absolute, use it directly + if target.is_absolute(): + model_dir = target + else: + # If target is relative, resolve it relative to the symlink's directory + model_dir = (latest_link.parent / target).resolve() + logger.info(f"Using latest checkpoint: {model_dir}") + else: + # Find most recent checkpoint + checkpoints = sorted([ + d for d in model_dir.iterdir() + if d.is_dir() and d.name.startswith('checkpoint_epoch') + ]) + if checkpoints: + model_dir = checkpoints[-1] + logger.info(f"Using most recent checkpoint: {model_dir}") + else: + logger.info("No checkpoints found, using base directory") + + logger.info(f"Loading model from: {model_dir}") + + # Verify the directory exists + if not model_dir.exists(): + raise FileNotFoundError(f"Model directory not found: {model_dir}") + + # Initialize the custom model architecture + model = LanguageAwareTransformer( + num_labels=6, + hidden_size=1024, + num_attention_heads=16, + model_name='xlm-roberta-large' + ) + + # Load the trained weights + weights_path = model_dir / 'pytorch_model.bin' + if not weights_path.exists(): + raise FileNotFoundError(f"Model weights not found at {weights_path}") + + state_dict = torch.load(weights_path) + model.load_state_dict(state_dict) + logger.info("Model weights loaded successfully") + + # Load base XLM-RoBERTa tokenizer directly + logger.info("Loading XLM-RoBERTa tokenizer...") + tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') + + # Load training metadata if available + metadata_path = model_dir / 'metadata.json' + if metadata_path.exists(): + with open(metadata_path) as f: + metadata = json.load(f) + logger.info(f"Loaded checkpoint metadata: Epoch {metadata.get('epoch', 'unknown')}") + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = model.to(device) + model.eval() + + return model, tokenizer, device + + except Exception as e: + logger.error(f"Error loading model: {str(e)}") + logger.error("\nPlease ensure that:") + logger.error("1. You have trained the model first using train.py") + logger.error("2. The model weights are saved in the correct location") + logger.error("3. 
You have sufficient permissions to access the model files") + return None, None, None + +def adjust_thresholds(thresholds): + """ + Adjust thresholds based on recommendations to reduce overflagging + """ + if not thresholds: + return thresholds + + adjusted = thresholds.copy() + # Adjust thresholds for each language + for lang_id in adjusted: + for category, recommended in DEFAULT_THRESHOLDS.items(): + if category in adjusted[lang_id]: + # Only increase threshold if recommended is higher + adjusted[lang_id][category] = max(adjusted[lang_id][category], recommended) + + return adjusted + +def analyze_unicode_ranges(text): + """Analyze text for characters in language-specific Unicode ranges""" + scores = {lang: 0 for lang in SUPPORTED_LANGUAGES.keys()} + + for char in text: + code = ord(char) + for lang, ranges in UNICODE_RANGES.items(): + for start, end in ranges: + if start <= code <= end: + scores[lang] += 1 + + return scores + +def analyze_tokenizer_stats(text, tokenizer): + """Analyze tokenizer statistics for language detection""" + # Get tokenizer output + tokens = tokenizer.tokenize(text) + + # Count language-specific token patterns + scores = {lang: 0 for lang in SUPPORTED_LANGUAGES.keys()} + + # Analyze token patterns + for token in tokens: + token = token.lower() + # Check for language-specific subwords + if 'en' in token or '_en' in token: + scores['en'] += 1 + elif 'ru' in token or '_ru' in token: + scores['ru'] += 1 + elif 'tr' in token or '_tr' in token: + scores['tr'] += 1 + elif 'es' in token or '_es' in token: + scores['es'] += 1 + elif 'fr' in token or '_fr' in token: + scores['fr'] += 1 + elif 'it' in token or '_it' in token: + scores['it'] += 1 + elif 'pt' in token or '_pt' in token: + scores['pt'] += 1 + + return scores + +def detect_language(text, tokenizer): + """ + Enhanced language detection using langdetect with multiple fallback methods: + 1. Primary: langdetect library + 2. Fallback 1: ASCII analysis for English + 3. Fallback 2: Unicode range analysis + 4. 
Fallback 3: Tokenizer statistics + """ + try: + # Clean text + text = text.strip() + + # If empty or just punctuation, default to English + if not text or not re.search(r'\w', text): + return SUPPORTED_LANGUAGES['en'] + + # Primary method: Use langdetect + try: + detected_code = detect(text) + # Map some common language codes that might differ + lang_mapping = { + 'eng': 'en', + 'rus': 'ru', + 'tur': 'tr', + 'spa': 'es', + 'fra': 'fr', + 'ita': 'it', + 'por': 'pt' + } + detected_code = lang_mapping.get(detected_code, detected_code) + + if detected_code in SUPPORTED_LANGUAGES: + return SUPPORTED_LANGUAGES[detected_code] + except LangDetectException: + pass # Continue to fallback methods + + # Fallback 1: If text is ASCII only, likely English + if all(ord(c) < 128 for c in text): + return SUPPORTED_LANGUAGES['en'] + + # Fallback 2 & 3: Combine Unicode analysis and tokenizer statistics + unicode_scores = analyze_unicode_ranges(text) + tokenizer_scores = analyze_tokenizer_stats(text, tokenizer) + + # Combine scores with weights + final_scores = {lang: 0 for lang in SUPPORTED_LANGUAGES.keys()} + for lang in SUPPORTED_LANGUAGES.keys(): + final_scores[lang] = ( + unicode_scores[lang] * 2 + # Unicode ranges have higher weight + tokenizer_scores[lang] + ) + + # Get language with highest score + if any(score > 0 for score in final_scores.values()): + detected_lang = max(final_scores.items(), key=lambda x: x[1])[0] + return SUPPORTED_LANGUAGES[detected_lang] + + # Default to English if no clear match + return SUPPORTED_LANGUAGES['en'] + + except Exception as e: + logger.warning(f"Language detection failed ({str(e)}). Using English.") + return SUPPORTED_LANGUAGES['en'] + +def predict_toxicity(text, model, tokenizer, device): + """Predict toxicity labels for a given text""" + # Detect language + lang_id = detect_language(text, tokenizer) + + # Tokenize text + encoding = tokenizer( + text, + max_length=128, + padding='max_length', + truncation=True, + return_tensors='pt' + ) + + # Move to device + input_ids = encoding['input_ids'].to(device) + attention_mask = encoding['attention_mask'].to(device) + + # Get predictions + with torch.no_grad(): + outputs = model(input_ids=input_ids, attention_mask=attention_mask) + predictions = outputs['probabilities'] + + # Convert to probabilities + probabilities = predictions[0].cpu().numpy() + + # Labels for toxicity types + labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + + # Create results dictionary using optimized thresholds + results = {} + for label, prob in zip(labels, probabilities): + threshold = DEFAULT_THRESHOLDS.get(label, 0.5) # Use optimized defaults + results[label] = { + 'probability': float(prob), + 'is_toxic': prob > threshold, + 'threshold': threshold + } + + return results, lang_id + +def main(): + # Load model + print("Loading model...") + model_path = 'weights/toxic_classifier_xlm-roberta-large/latest' + model, tokenizer, device = load_model(model_path) + + if model is None or tokenizer is None: + return + + while True: + try: + # Get input text with proper Unicode handling + print("\nEnter text to analyze (or 'q' to quit):") + try: + if sys.platform == 'win32': + # Windows-specific input handling + text = sys.stdin.buffer.readline().decode('utf-8').strip() + else: + text = input().strip() + except UnicodeDecodeError: + # Fallback to latin-1 if UTF-8 fails + if sys.platform == 'win32': + text = sys.stdin.buffer.readline().decode('latin-1').strip() + else: + text = 
sys.stdin.buffer.readline().decode('latin-1').strip() + + if text.lower() == 'q': + break + + if not text: + print("Please enter some text to analyze.") + continue + + # Make prediction + print("\nAnalyzing text...") + predictions, lang_id = predict_toxicity(text, model, tokenizer, device) + + # Get language name + lang_name = [k for k, v in SUPPORTED_LANGUAGES.items() if v == lang_id][0] + + # Print results + print("\nResults:") + print("-" * 50) + print(f"Text: {text}") + print(f"Detected Language: {lang_name}") + print("\nToxicity Analysis:") + + any_toxic = False + for label, result in predictions.items(): + if result['is_toxic']: + any_toxic = True + print(f"- {label}: {result['probability']:.2%} (threshold: {result['threshold']:.2%}) ⚠️") + + # Print non-toxic results with lower emphasis + print("\nOther categories:") + for label, result in predictions.items(): + if not result['is_toxic']: + print(f"- {label}: {result['probability']:.2%} (threshold: {result['threshold']:.2%}) ✓") + + # Overall assessment + print("\nOverall Assessment:") + if any_toxic: + print("⚠️ This text contains toxic content") + else: + print("✅ This text appears to be non-toxic") + + except Exception as e: + logger.error(f"Unexpected error: {str(e)}") + print("\nAn unexpected error occurred. Please try again.") + continue + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/model/train.py b/model/train.py new file mode 100644 index 0000000000000000000000000000000000000000..3f77205d40bef73b183b6ccf40ba280ef2d107d3 --- /dev/null +++ b/model/train.py @@ -0,0 +1,720 @@ +# train.py +import pandas as pd +import torch +import logging +import os +import gc +import wandb +from datetime import datetime +import signal +import atexit +import sys +from pathlib import Path +import numpy as np +import warnings +import json +from tqdm import tqdm +import torch.nn as nn +import torch.nn.functional as F +import time + +from transformers import ( + XLMRobertaTokenizer +) +from torch.utils.data import DataLoader +from model.evaluation.evaluate import ToxicDataset +from model.training_config import MetricsTracker, TrainingConfig +from model.data.sampler import MultilabelStratifiedSampler +from model.language_aware_transformer import LanguageAwareTransformer +from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(f'logs/train_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'), + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +# Set environment variables if not already set +os.environ['TF_CPP_MIN_LOG_LEVEL'] = os.environ.get('TF_CPP_MIN_LOG_LEVEL', '2') +warnings.filterwarnings("ignore", message="Was asked to gather along dimension 0") +warnings.filterwarnings("ignore", message="AVX2 detected") + +# Initialize global variables with None +_model = None +_optimizer = None +_scheduler = None +_cleanup_handlers = [] + +def register_cleanup(handler): + """Register cleanup handlers that will be called on exit""" + _cleanup_handlers.append(handler) + +def cleanup(): + """Cleanup function to be called on exit""" + global _model, _optimizer, _scheduler + + print("\nPerforming cleanup...") + + for handler in _cleanup_handlers: + try: + handler() + except Exception as e: + print(f"Warning: Cleanup handler failed: {str(e)}") + + if torch.cuda.is_available(): + try: + torch.cuda.empty_cache() + except Exception as e: + print(f"Warning: Could not clear CUDA 
cache: {str(e)}") + + try: + if _model is not None: + del _model + if _optimizer is not None: + del _optimizer + if _scheduler is not None: + del _scheduler + except Exception as e: + print(f"Warning: Error during cleanup: {str(e)}") + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + +# Register cleanup handlers +atexit.register(cleanup) + +def signal_handler(signum, frame): + print(f"\nReceived signal {signum}. Cleaning up...") + cleanup() + sys.exit(0) + +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + +def init_model(config): + """Initialize model with error handling""" + global _model + + try: + _model = LanguageAwareTransformer( + num_labels=config.num_labels, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + model_name=config.model_name, + dropout=config.model_dropout + ) + + assert config.hidden_size == 1024, "XLM-R hidden size must be 1024" + assert _model.base_model.config.num_attention_heads == 16, "Head count mismatch" + + if config.freeze_layers > 0: + for param in list(_model.base_model.parameters())[:8]: + param.requires_grad = False + + assert not any([p.requires_grad for p in _model.base_model.parameters()][:8]), "First 8 layers should be frozen" + + # Enhanced gradient checkpointing setup + if config.activation_checkpointing: + logger.info("Enabling gradient checkpointing for memory efficiency") + _model.gradient_checkpointing = True + _model.base_model.gradient_checkpointing_enable() + _model.base_model._set_gradient_checkpointing(enable=True) + + # Verify checkpointing is enabled + assert _model.base_model.is_gradient_checkpointing, "Gradient checkpointing failed to enable" + + _model = _model.to(config.device) + return _model + + except Exception as e: + logger.error(f"Fatal error initializing model: {str(e)}") + raise + +def get_grad_stats(model): + """Calculate gradient statistics for monitoring""" + try: + grad_norms = [] + grad_means = [] + grad_maxs = [] + grad_mins = [] + param_names = [] + + for name, param in model.named_parameters(): + if param.grad is not None: + grad = param.grad + grad_norm = grad.norm().item() + grad_norms.append(grad_norm) + grad_means.append(grad.mean().item()) + grad_maxs.append(grad.max().item()) + grad_mins.append(grad.min().item()) + param_names.append(name) + + if grad_norms: + return { + 'grad/max_norm': max(grad_norms), + 'grad/min_norm': min(grad_norms), + 'grad/mean_norm': sum(grad_norms) / len(grad_norms), + 'grad/max_value': max(grad_maxs), + 'grad/min_value': min(grad_mins), + 'grad/mean_value': sum(grad_means) / len(grad_means), + 'grad/largest_layer': param_names[grad_norms.index(max(grad_norms))], + 'grad/smallest_layer': param_names[grad_norms.index(min(grad_norms))] + } + return {} + except Exception as e: + logger.warning(f"Error calculating gradient stats: {str(e)}") + return {} + +class LanguageAwareFocalLoss(nn.Module): + def __init__(self, reduction='mean'): + super().__init__() + self.reduction = reduction + + def forward(self, inputs, targets, lang_weights=None, alpha=None, gamma=None): + """ + Compute focal loss with language-aware weighting and per-class parameters + Args: + inputs: Model predictions [batch_size, num_classes] + targets: Target labels [batch_size, num_classes] + lang_weights: Optional language weights [batch_size, num_classes] + alpha: Optional class-wise weight factor [num_classes] or [batch_size, num_classes] + gamma: Optional focusing parameter [num_classes] or [batch_size, num_classes] 
+ """ + if alpha is None: + alpha = torch.full_like(inputs, 0.25) + if gamma is None: + gamma = torch.full_like(inputs, 2.0) + + # Ensure alpha and gamma have correct shape [batch_size, num_classes] + if alpha.dim() == 1: + alpha = alpha.unsqueeze(0).expand(inputs.size(0), -1) + if gamma.dim() == 1: + gamma = gamma.unsqueeze(0).expand(inputs.size(0), -1) + + # Compute binary cross entropy without reduction + bce_loss = F.binary_cross_entropy_with_logits( + inputs, targets, reduction='none' + ) + + # Compute probabilities for focusing + pt = torch.exp(-bce_loss) # [batch_size, num_classes] + + # Compute focal weights with per-class gamma + focal_weights = (1 - pt) ** gamma # [batch_size, num_classes] + + # Apply alpha weighting per-class + weighted_focal_loss = alpha * focal_weights * bce_loss + + # Apply language-specific weights if provided + if lang_weights is not None: + weighted_focal_loss = weighted_focal_loss * lang_weights + + # Reduce if needed + if self.reduction == 'mean': + return weighted_focal_loss.mean() + elif self.reduction == 'sum': + return weighted_focal_loss.sum() + return weighted_focal_loss + +def training_step(batch, model, optimizer, scheduler, config, scaler, batch_idx): + """Execute a single training step with gradient accumulation""" + # Move batch to device + batch = {k: v.to(config.device) if isinstance(v, torch.Tensor) else v + for k, v in batch.items()} + + # Calculate language weights and focal parameters + lang_weights = None + alpha = None + gamma = None + + if hasattr(config, 'lang_weights') and config.lang_weights is not None: + weight_dict = config.lang_weights.get_weights_for_batch( + [lang.item() for lang in batch['lang']], + batch['labels'], + config.device + ) + lang_weights = weight_dict['weights'] # [batch_size, num_classes] + alpha = weight_dict['alpha'] # [num_classes] + gamma = weight_dict['gamma'] # [num_classes] + else: + # Default focal parameters if no language weights + num_classes = batch['labels'].size(1) + alpha = torch.full((num_classes,), 0.25, device=config.device) + gamma = torch.full((num_classes,), 2.0, device=config.device) + + # Forward pass + with config.get_autocast_context(): + outputs = model( + input_ids=batch['input_ids'], + attention_mask=batch['attention_mask'], + labels=batch['labels'], + lang_ids=batch['lang'] + ) + + # Calculate loss with per-class focal parameters + loss_fct = LanguageAwareFocalLoss() + loss = loss_fct( + outputs['logits'], + batch['labels'].float(), + lang_weights=lang_weights, + alpha=alpha, + gamma=gamma + ) + outputs['loss'] = loss + + # Check for numerical instability + if torch.isnan(loss).any() or torch.isinf(loss).any(): + logger.error(f"Numerical instability detected! 
Loss: {loss.item()}") + logger.error(f"Batch stats - input_ids shape: {batch['input_ids'].shape}, labels shape: {batch['labels'].shape}") + if lang_weights is not None: + logger.error(f"Weights stats - min: {lang_weights.min():.3f}, max: {lang_weights.max():.3f}") + logger.error(f"Focal params - gamma range: [{gamma.min():.3f}, {gamma.max():.3f}], alpha range: [{alpha.min():.3f}, {alpha.max():.3f}]") + optimizer.zero_grad() + return None + + # Scale loss for gradient accumulation + if config.grad_accum_steps > 1: + loss = loss / config.grad_accum_steps + + # Backward pass with scaled loss + scaler.scale(loss).backward() + + # Only update weights after accumulating enough gradients + if (batch_idx + 1) % config.grad_accum_steps == 0: + # Log gradient stats before clipping + if batch_idx % 100 == 0: + grad_stats = get_grad_stats(model) + if grad_stats: + logger.debug("Gradient stats before clipping:") + for key, value in grad_stats.items(): + logger.debug(f"{key}: {value}") + + # Gradient clipping + if config.max_grad_norm > 0: + # Unscale gradients before clipping + scaler.unscale_(optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), + config.max_grad_norm + ) + if grad_norm.isnan() or grad_norm.isinf(): + logger.warning(f"Gradient norm is {grad_norm}, skipping optimizer step") + optimizer.zero_grad() + return loss.item() * config.grad_accum_steps # Return unscaled loss for logging + + # Optimizer step with scaler + scaler.step(optimizer) + scaler.update() + + # Zero gradients after optimizer step + optimizer.zero_grad(set_to_none=True) # More efficient than zero_grad() + + # Step scheduler after optimization + scheduler.step() + + # Log gradient stats after update + if batch_idx % 100 == 0: + grad_stats = get_grad_stats(model) + if grad_stats: + logger.debug("Gradient stats after update:") + for key, value in grad_stats.items(): + logger.debug(f"{key}: {value}") + + # Return the original (unscaled) loss for logging + return loss.item() * config.grad_accum_steps if config.grad_accum_steps > 1 else loss.item() + +def save_checkpoint(model, optimizer, scheduler, metrics, config, epoch): + """Save model checkpoint with versioning and timestamps""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # Create base checkpoint directory + base_dir = Path('weights/toxic_classifier_xlm-roberta-large') + base_dir.mkdir(parents=True, exist_ok=True) + + # Create versioned checkpoint directory + checkpoint_dir = base_dir / f"checkpoint_epoch{epoch:02d}_{timestamp}" + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"Saving checkpoint to {checkpoint_dir}") + + try: + # Save model state + model_save_path = checkpoint_dir / 'pytorch_model.bin' + torch.save(model.state_dict(), model_save_path) + logger.info(f"Saved model state to {model_save_path}") + + # Save training state + training_state = { + 'epoch': epoch, + 'optimizer_state': optimizer.state_dict(), + 'scheduler_state': scheduler.state_dict(), + 'metrics': { + 'train_loss': metrics.train_losses[-1] if metrics.train_losses else None, + 'best_auc': metrics.best_auc, + 'timestamp': timestamp + } + } + state_save_path = checkpoint_dir / 'training_state.pt' + torch.save(training_state, state_save_path) + logger.info(f"Saved training state to {state_save_path}") + + # Save config + config_save_path = checkpoint_dir / 'config.json' + with open(config_save_path, 'w') as f: + json.dump(config.to_serializable_dict(), f, indent=2) + logger.info(f"Saved config to {config_save_path}") + + # Save checkpoint 
metadata + metadata = { + 'timestamp': timestamp, + 'epoch': epoch, + 'model_size': os.path.getsize(model_save_path) / (1024 * 1024), # Size in MB + 'git_commit': os.environ.get('GIT_COMMIT', 'unknown'), + 'training_metrics': { + 'loss': metrics.train_losses[-1] if metrics.train_losses else None, + 'best_auc': metrics.best_auc + } + } + meta_save_path = checkpoint_dir / 'metadata.json' + with open(meta_save_path, 'w') as f: + json.dump(metadata, f, indent=2) + logger.info(f"Saved checkpoint metadata to {meta_save_path}") + + # Only create symlink after all files are saved successfully + latest_path = base_dir / 'latest' + if latest_path.exists(): + latest_path.unlink() # Remove existing symlink if it exists + + # Create relative symlink + os.symlink(checkpoint_dir.name, latest_path) + logger.info(f"Updated 'latest' symlink to point to {checkpoint_dir.name}") + + # Cleanup old checkpoints if needed + keep_last_n = 3 # Keep last 3 checkpoints + all_checkpoints = sorted([d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith('checkpoint')]) + if len(all_checkpoints) > keep_last_n: + for old_checkpoint in all_checkpoints[:-keep_last_n]: + try: + import shutil + shutil.rmtree(old_checkpoint) + logger.info(f"Removed old checkpoint: {old_checkpoint}") + except Exception as e: + logger.warning(f"Failed to remove old checkpoint {old_checkpoint}: {str(e)}") + + logger.info(f"Successfully saved checkpoint for epoch {epoch + 1}") + return checkpoint_dir + + except Exception as e: + logger.error(f"Error saving checkpoint: {str(e)}") + logger.error("Checkpoint save failed with traceback:", exc_info=True) + # If checkpoint save fails, ensure we don't leave a broken symlink + latest_path = base_dir / 'latest' + if latest_path.exists(): + latest_path.unlink() + raise + +def train(model, train_loader, config): + """Train the model""" + global _model, _optimizer, _scheduler + _model = model + + logger.info("Initializing training components...") + logger.info(f"Using gradient accumulation with {config.grad_accum_steps} steps") + logger.info(f"Effective batch size: {config.batch_size * config.grad_accum_steps}") + + # Initialize gradient scaler for mixed precision + logger.info("Setting up gradient scaler...") + scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp) + + logger.info("Creating optimizer...") + optimizer = torch.optim.AdamW( + config.get_param_groups(model), + weight_decay=config.weight_decay + ) + _optimizer = optimizer + + # Calculate total steps for cosine scheduler + total_steps = (len(train_loader) // config.grad_accum_steps) * config.epochs + warmup_steps = int(total_steps * config.warmup_ratio) + logger.info(f"Training schedule: {total_steps} total steps, {warmup_steps} warmup steps") + logger.info(f"Actual number of batches per epoch: {len(train_loader)}") + + # Initialize cosine scheduler with warm restarts + logger.info("Creating learning rate scheduler...") + scheduler = CosineAnnealingWarmRestarts( + optimizer, + T_0=total_steps // config.num_cycles, + T_mult=1, + eta_min=config.lr * config.min_lr_ratio + ) + _scheduler = scheduler + + # Initialize metrics tracker + metrics = MetricsTracker() + + logger.info("Starting training loop...") + # Training loop + model.train() + + # Verify data loader is properly initialized + try: + logger.info("Verifying data loader...") + test_batch = next(iter(train_loader)) + logger.info(f"Data loader test successful. 
Batch keys: {list(test_batch.keys())}") + logger.info(f"Input shape: {test_batch['input_ids'].shape}") + logger.info(f"Label shape: {test_batch['labels'].shape}") + except Exception as e: + logger.error(f"Data loader verification failed: {str(e)}") + raise + + for epoch in range(config.epochs): + epoch_loss = 0 + num_batches = 0 + + logger.info(f"Starting epoch {epoch + 1}/{config.epochs}") + + # Create progress bar with additional metrics + progress_bar = tqdm( + train_loader, + desc=f"Epoch {epoch + 1}/{config.epochs}", + dynamic_ncols=True, # Adapt to terminal width + leave=True # Keep progress bar after completion + ) + + optimizer.zero_grad(set_to_none=True) # More efficient gradient clearing + + logger.info("Iterating through batches...") + batch_start_time = time.time() + + for batch_idx, batch in enumerate(progress_bar): + try: + # Log first batch details + if batch_idx == 0: + logger.info("Successfully loaded first batch") + logger.info(f"Batch shapes - input_ids: {batch['input_ids'].shape}, " + f"attention_mask: {batch['attention_mask'].shape}, " + f"labels: {batch['labels'].shape}") + logger.info(f"Memory usage: {torch.cuda.memory_allocated() / 1024**2:.1f}MB") + + # Execute training step + loss = training_step(batch, model, optimizer, scheduler, config, scaler, batch_idx) + + if loss is not None: + epoch_loss += loss + num_batches += 1 + + # Calculate batch processing time + batch_time = time.time() - batch_start_time + + # Format loss string outside of the postfix dict + loss_str = "N/A" if loss is None else f"{loss:.4f}" + + # Update progress bar with detailed metrics + progress_bar.set_postfix({ + 'loss': loss_str, + 'lr': f"{scheduler.get_last_lr()[0]:.2e}", + 'batch_time': f"{batch_time:.2f}s", + 'processed': f"{(batch_idx + 1) * config.batch_size}" + }) + + # Log to wandb with more frequent updates + if (batch_idx + 1) % max(1, config.grad_accum_steps // 2) == 0: + try: + wandb.log({ + 'batch_loss': loss if loss is not None else 0, + 'learning_rate': scheduler.get_last_lr()[0], + 'batch_time': batch_time, + 'gpu_memory': torch.cuda.memory_allocated() / 1024**2 + }) + except Exception as e: + logger.warning(f"Could not log to wandb: {str(e)}") + + # More frequent logging for debugging + if batch_idx % 10 == 0: + loss_debug_str = "N/A" if loss is None else f"{loss:.4f}" + logger.debug( + f"Batch {batch_idx}/{len(train_loader)}: " + f"Loss={loss_debug_str}, " + f"Time={batch_time:.2f}s" + ) + + # Memory management + if batch_idx % config.gc_frequency == 0: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + batch_start_time = time.time() + + except Exception as e: + logger.error(f"Error in batch {batch_idx}: {str(e)}") + logger.error("Batch contents:") + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + logger.error(f"{k}: shape={v.shape}, dtype={v.dtype}, device={v.device}") + else: + logger.error(f"{k}: type={type(v)}") + if torch.cuda.is_available(): + logger.error(f"GPU Memory: {torch.cuda.memory_allocated() / 1024**2:.1f}MB") + continue + + # Calculate average epoch loss + avg_epoch_loss = epoch_loss / num_batches if num_batches > 0 else float('inf') + metrics.update_train(avg_epoch_loss) + logger.info(f"Epoch {epoch + 1} completed. 
Average loss: {avg_epoch_loss:.4f}") + + # Save checkpoint + try: + save_checkpoint(model, optimizer, scheduler, metrics, config, epoch) + logger.info(f"Saved checkpoint for epoch {epoch + 1}") + except Exception as e: + logger.error(f"Could not save checkpoint: {str(e)}") + + # Log epoch metrics + try: + wandb.log({ + 'epoch': epoch + 1, + 'epoch_loss': avg_epoch_loss, + 'best_auc': metrics.best_auc, + 'learning_rate': scheduler.get_last_lr()[0], + 'gpu_memory': torch.cuda.memory_allocated() / 1024**2 if torch.cuda.is_available() else 0 + }) + except Exception as e: + logger.error(f"Could not log epoch metrics to wandb: {str(e)}") + +def create_dataloaders(train_dataset, val_dataset, config): + """Create DataLoader with simplified settings""" + logger.info("Creating data loader...") + + # Create sampler + train_sampler = MultilabelStratifiedSampler( + labels=train_dataset.labels, + groups=train_dataset.langs, + batch_size=config.batch_size + ) + + # Create DataLoader with minimal settings + train_loader = DataLoader( + train_dataset, + batch_size=config.batch_size, + sampler=train_sampler, + num_workers=0, # Disable multiprocessing for now + pin_memory=torch.cuda.is_available(), + drop_last=False + ) + + # Verify DataLoader + logger.info("Testing DataLoader...") + try: + test_batch = next(iter(train_loader)) + logger.info("DataLoader test successful") + return train_loader + except Exception as e: + logger.error(f"DataLoader test failed: {str(e)}") + raise + +def main(): + try: + # Set environment variables for CUDA and multiprocessing + os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128' + os.environ['OMP_NUM_THREADS'] = '1' # Limit OpenMP threads + os.environ['MKL_NUM_THREADS'] = '1' # Limit MKL threads + + logger.info("Initializing training configuration...") + # Initialize config first + config = TrainingConfig() + + # Initialize CUDA settings + if torch.cuda.is_available(): + # Disable TF32 on Ampere GPUs + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + # Set deterministic mode + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + # Clear CUDA cache + torch.cuda.empty_cache() + + # Set device to current CUDA device + torch.cuda.set_device(torch.cuda.current_device()) + + logger.info(f"Using CUDA device: {torch.cuda.get_device_name()}") + logger.info("Configured CUDA settings for stability") + + # Initialize wandb + try: + wandb.init( + project="toxic-comment-classification", + name=f"toxic-{datetime.now().strftime('%Y%m%d-%H%M%S')}", + config=config.to_serializable_dict() + ) + logger.info("Initialized wandb logging") + except Exception as e: + logger.warning(f"Could not initialize wandb: {str(e)}") + + global _model, _optimizer, _scheduler + _model = None + _optimizer = None + _scheduler = None + + logger.info("Loading datasets...") + try: + train_df = pd.read_csv("dataset/split/train.csv") + logger.info(f"Loaded train dataset with {len(train_df)} samples") + except Exception as e: + logger.error(f"Error loading datasets: {str(e)}") + raise + + try: + logger.info("Creating tokenizer and dataset...") + tokenizer = XLMRobertaTokenizer.from_pretrained(config.model_name) + train_dataset = ToxicDataset(train_df, tokenizer, config) + logger.info("Dataset creation successful") + except Exception as e: + logger.error(f"Error creating datasets: {str(e)}") + raise + + logger.info("Creating data loaders...") + train_loader = create_dataloaders(train_dataset, 
None, config) + + logger.info("Initializing model...") + model = init_model(config) + + logger.info("Starting training...") + train(model, train_loader, config) + + except KeyboardInterrupt: + print("\nTraining interrupted by user") + cleanup() + except Exception as e: + print(f"Error during training: {str(e)}") + import traceback + traceback.print_exc() + raise + finally: + if wandb.run is not None: + try: + wandb.finish() + except Exception as e: + print(f"Warning: Could not finish wandb run: {str(e)}") + cleanup() + +if __name__ == "__main__": + # Set global PyTorch settings + torch.set_num_threads(1) # Limit CPU threads + np.set_printoptions(precision=4, suppress=True) + torch.set_printoptions(precision=4, sci_mode=False) + + try: + main() + except Exception as e: + print(f"Fatal error: {str(e)}") + cleanup() + sys.exit(1) \ No newline at end of file diff --git a/model/training_config.py b/model/training_config.py new file mode 100644 index 0000000000000000000000000000000000000000..353c8fd068bd59d51dddc8c12632279d04c95f7b --- /dev/null +++ b/model/training_config.py @@ -0,0 +1,476 @@ +# training_config.py +from asyncio.log import logger +from dataclasses import dataclass +from typing import Dict, List +import json +import torch +import numpy as np +from pathlib import Path +from contextlib import nullcontext +from dataclasses import asdict +import os + +@dataclass +class DynamicClassWeights: + """Handles class weights per language using dynamic batch statistics""" + weights_file: str = 'weights/language_class_weights.json' + + def __init__(self, weights_file: str = 'weights/language_class_weights.json'): + self.weights_file = weights_file + self.toxicity_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + self.language_columns = ['en', 'es', 'fr', 'it', 'tr', 'pt', 'ru'] + + # Initialize base scaling factors from file if available + try: + with open(self.weights_file, 'r') as f: + data = json.load(f) + self.lang_scaling = {} + for lang in self.language_columns: + if lang in data['weights']: + # Calculate average scaling per language + scales = [float(data['weights'][lang][label]['1']) + for label in self.toxicity_labels] + self.lang_scaling[lang] = sum(scales) / len(scales) + else: + self.lang_scaling[lang] = 1.0 + except Exception as e: + logger.warning(f"Could not load weights from {self.weights_file}: {str(e)}") + self._initialize_defaults() + + # Initialize running statistics for each language + self.running_stats = {lang: { + 'pos_counts': torch.zeros(len(self.toxicity_labels)), + 'total_counts': torch.zeros(len(self.toxicity_labels)), + 'smoothing_factor': 0.95 # EMA smoothing factor + } for lang in self.language_columns} + + def _initialize_defaults(self): + """Initialize safe default scaling factors""" + self.lang_scaling = {lang: 1.0 for lang in self.language_columns} + + def _update_running_stats(self, langs, labels): + """Update running statistics for each language""" + unique_langs = set(langs) + for lang in unique_langs: + if lang not in self.running_stats: + continue + + lang_mask = torch.tensor([l == lang for l in langs], dtype=torch.bool) + lang_labels = labels[lang_mask] + + if len(lang_labels) == 0: + continue + + # Calculate current batch statistics + pos_count = lang_labels.sum(dim=0).float() + total_count = torch.full_like(pos_count, len(lang_labels)) + + # Update running statistics with EMA + alpha = self.running_stats[lang]['smoothing_factor'] + self.running_stats[lang]['pos_counts'] = ( + alpha * 
self.running_stats[lang]['pos_counts'] + + (1 - alpha) * pos_count + ) + self.running_stats[lang]['total_counts'] = ( + alpha * self.running_stats[lang]['total_counts'] + + (1 - alpha) * total_count + ) + + def get_weights_for_batch(self, langs: List[str], labels: torch.Tensor, device: torch.device) -> Dict[str, torch.Tensor]: + """ + Calculate dynamic weights and focal parameters based on batch and historical statistics + Args: + langs: List of language codes + labels: Binary labels tensor [batch_size, num_labels] + device: Target device for tensors + Returns: + Dict with weights, alpha, and gamma tensors + """ + try: + batch_size = len(langs) + num_labels = labels.size(1) + + # Update running statistics + self._update_running_stats(langs, labels) + + # Calculate positive ratio per language in current batch + lang_pos_ratios = {} + batch_pos_ratios = torch.zeros(num_labels, device=device) + lang_counts = {} + + for lang in set(langs): + lang_mask = torch.tensor([l == lang for l in langs], dtype=torch.bool, device=device) + if not lang_mask.any(): + continue + + # Calculate language-specific positive ratio + lang_labels = labels[lang_mask] + lang_pos_ratio = lang_labels.float().mean(dim=0) + lang_pos_ratios[lang] = lang_pos_ratio + + # Weighted contribution to batch statistics + lang_count = lang_mask.sum() + lang_counts[lang] = lang_count + batch_pos_ratios += lang_pos_ratio * (lang_count / batch_size) + + # Combine batch and historical statistics + weights = torch.ones(batch_size, num_labels, device=device) + alpha = torch.zeros(num_labels, device=device) + gamma = torch.zeros(num_labels, device=device) + + for i, (lang, label_vec) in enumerate(zip(langs, labels)): + if lang not in self.running_stats: + continue + + # Get historical statistics for this language + hist_pos_ratio = ( + self.running_stats[lang]['pos_counts'] / + (self.running_stats[lang]['total_counts'] + 1e-7) + ).to(device) + + # Combine historical and current batch statistics + current_pos_ratio = lang_pos_ratios.get(lang, batch_pos_ratios) + combined_pos_ratio = 0.7 * hist_pos_ratio + 0.3 * current_pos_ratio + + # Calculate stable weights using log-space + log_ratio = torch.log1p(1.0 / (combined_pos_ratio + 1e-7)) + class_weights = torch.exp(log_ratio.clamp(-2, 2)) + + # Apply language-specific scaling + weights[i] = class_weights * self.lang_scaling.get(lang, 1.0) + + # Update focal parameters + alpha_contrib = 1.0 / (combined_pos_ratio + 1e-7).clamp(0.05, 0.95) + gamma_contrib = log_ratio.clamp(1.0, 4.0) + + # Accumulate weighted contributions + weight = lang_counts.get(lang, 1) / batch_size + alpha += alpha_contrib * weight + gamma += gamma_contrib * weight + + # Apply class-specific adjustments based on statistical analysis + # Order: toxic, severe_toxic, obscene, threat, insult, identity_hate + class_adjustments = { + 'en': [1.0, 1.0, 0.9, 0.85, 1.1, 1.0], # English has more obscene/threat + 'ru': [1.0, 1.0, 1.0, 1.0, 0.9, 1.0], # Russian has more insults + 'tr': [1.0, 1.0, 1.0, 1.0, 0.9, 0.95], # Turkish pattern + 'es': [1.0, 1.0, 1.0, 1.0, 0.9, 1.0], # Spanish pattern + 'fr': [1.0, 1.0, 1.0, 1.0, 0.9, 1.0], # French pattern + 'it': [1.0, 1.0, 1.0, 1.0, 0.9, 1.0], # Italian pattern + 'pt': [1.0, 1.0, 1.0, 1.0, 0.9, 1.0] # Portuguese pattern + } + + # Apply adjustments to weights + for i, lang in enumerate(langs): + if lang in class_adjustments: + # Multiply weights by language-specific class adjustments + weights[i] *= torch.tensor(class_adjustments[lang], device=device) + + # Normalize weights to prevent extreme 
values + weights = weights / weights.mean() + + return { + 'weights': weights.clamp(0.1, 10.0), # Prevent extreme values + 'alpha': alpha.clamp(0.1, 5.0), # [num_labels] + 'gamma': gamma.clamp(1.0, 4.0) # [num_labels] + } + + except Exception as e: + logger.error(f"Error computing batch weights: {str(e)}") + # Fallback to safe default values + return { + 'weights': torch.ones((batch_size, num_labels), device=device), + 'alpha': torch.full((num_labels,), 0.25, device=device), + 'gamma': torch.full((num_labels,), 2.0, device=device) + } + +@dataclass +class MetricsTracker: + """Tracks training and validation metrics with error handling""" + best_auc: float = 0.0 + train_losses: List[float] = None + val_losses: List[float] = None + val_aucs: List[float] = None + epoch_times: List[float] = None + + def __post_init__(self): + self.train_losses = [] + self.val_losses = [] + self.val_aucs = [] + self.epoch_times = [] + + def update_train(self, loss: float): + """Update training metrics with validation""" + try: + if not isinstance(loss, (int, float)) or np.isnan(loss) or np.isinf(loss): + print(f"Warning: Invalid loss value: {loss}") + return + self.train_losses.append(float(loss)) + except Exception as e: + print(f"Warning: Could not update training metrics: {str(e)}") + + def update_validation(self, metrics: Dict) -> bool: + """Update validation metrics with error handling""" + try: + if not isinstance(metrics, dict): + raise ValueError("Metrics must be a dictionary") + + loss = metrics.get('loss', float('inf')) + auc = metrics.get('auc', 0.0) + + if np.isnan(loss) or np.isinf(loss): + print(f"Warning: Invalid loss value: {loss}") + loss = float('inf') + + if np.isnan(auc) or np.isinf(auc): + print(f"Warning: Invalid AUC value: {auc}") + auc = 0.0 + + self.val_losses.append(float(loss)) + self.val_aucs.append(float(auc)) + + # Update best AUC if needed + if auc > self.best_auc: + self.best_auc = auc + return True + return False + + except Exception as e: + print(f"Warning: Could not update validation metrics: {str(e)}") + return False + + def update_time(self, epoch_time: float): + """Update timing metrics with validation""" + try: + if not isinstance(epoch_time, (int, float)) or epoch_time <= 0: + print(f"Warning: Invalid epoch time: {epoch_time}") + return + self.epoch_times.append(float(epoch_time)) + except Exception as e: + print(f"Warning: Could not update timing metrics: {str(e)}") + + def get_eta(self, current_epoch: int, total_epochs: int) -> str: + """Calculate ETA based on average epoch time with error handling""" + try: + if not self.epoch_times: + return "Calculating..." 
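+            # Worked example with hypothetical numbers: if epoch_times = [3500, 3700]
+            # after epoch 2 of 6, avg_epoch_time is 3600s and 4 epochs remain, so
+            # eta_seconds = 14400 and the formatted string below is "04:00:00".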
+ + if current_epoch >= total_epochs: + return "Complete" + + avg_epoch_time = sum(self.epoch_times) / len(self.epoch_times) + remaining_epochs = total_epochs - current_epoch + eta_seconds = avg_epoch_time * remaining_epochs + + hours = int(eta_seconds // 3600) + minutes = int((eta_seconds % 3600) // 60) + + return f"{hours:02d}:{minutes:02d}:00" + + except Exception as e: + print(f"Warning: Could not calculate ETA: {str(e)}") + return "Unknown" + +@dataclass +class TrainingConfig: + """Basic training configuration with consolidated default values""" + # Model parameters + model_name: str = "xlm-roberta-large" + max_length: int = 512 + hidden_size: int = 1024 + num_attention_heads: int = 16 + model_dropout: float = 0.0 + freeze_layers: int = 8 + + # Dataset parameters + cache_dir: str = 'cached_dataset' + label_columns: List[str] = None # Will be initialized in __post_init__ + + # Training parameters + batch_size: int = 128 + grad_accum_steps: int = 1 + epochs: int = 6 + lr: float = 2e-5 + num_cycles: int = 2 + weight_decay: float = 2e-7 + max_grad_norm: float = 1.0 + warmup_ratio: float = 0.1 + label_smoothing: float = 0.01 + min_lr_ratio: float = 0.01 + + # Memory optimization + activation_checkpointing: bool = True + mixed_precision: str = "fp16" + _num_workers: int = None # Private storage for num_workers + gc_frequency: int = 500 + tensor_float_32: bool = True + + # Cosine scheduler parameters + num_cycles: int = 2 + + def __post_init__(self): + """Initialize and validate configuration""" + # Initialize label columns + self.label_columns = [ + 'toxic', 'severe_toxic', 'obscene', + 'threat', 'insult', 'identity_hate' + ] + + # Set environment variables for memory optimization + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,expandable_segments:True' + + # Rest of the initialization code... + if self.lr <= 0: + raise ValueError(f"Learning rate must be positive, got {self.lr}") + if self.lr < 1e-7: + raise ValueError(f"Learning rate too small: {self.lr}") + if self.lr > 1.0: + raise ValueError(f"Learning rate too large: {self.lr}") + + # Validate weight decay and learning rate combination + if self.weight_decay > 0: + wd_to_lr_ratio = self.weight_decay / self.lr + if wd_to_lr_ratio > 0.1: + logger.warning( + "Weight decay too high: %.2e (%.2fx learning rate). " + "Should be 0.01-0.1x learning rate.", + self.weight_decay, wd_to_lr_ratio + ) + effective_lr = self.lr * (1 - self.weight_decay) + if effective_lr < self.lr * 0.9: + logger.warning( + "Weight decay %.2e reduces effective learning rate to %.2e (%.1f%% reduction)", + self.weight_decay, effective_lr, (1 - effective_lr/self.lr) * 100 + ) + + # Set device with memory optimization + if torch.cuda.is_available(): + try: + torch.cuda.init() + # Set memory allocation strategy + torch.cuda.set_per_process_memory_fraction(0.95) # Leave some GPU memory free + self.device = torch.device('cuda') + + if self.mixed_precision == "bf16": + if not torch.cuda.is_bf16_supported(): + print("Warning: BF16 not supported on this GPU. Falling back to FP16") + self.mixed_precision = "fp16" + + if self.tensor_float_32: + if torch.cuda.get_device_capability()[0] >= 8: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + else: + print("Warning: TF32 not supported on this GPU. 
Disabling.") + self.tensor_float_32 = False + + except Exception as e: + print(f"Warning: CUDA initialization failed: {str(e)}") + self.device = torch.device('cpu') + self.mixed_precision = "no" + else: + self.device = torch.device('cpu') + if self.mixed_precision != "no": + print("Warning: Mixed precision not supported on CPU. Disabling.") + self.mixed_precision = "no" + + # Create directories with error handling + try: + for directory in ["weights", "logs"]: + dir_path = Path(directory) + if not dir_path.exists(): + dir_path.mkdir(parents=True) + elif not dir_path.is_dir(): + raise NotADirectoryError(f"{directory} exists but is not a directory") + except Exception as e: + print(f"Error creating directories: {str(e)}") + raise + + # Initialize toxicity labels + self.toxicity_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + self.num_labels = len(self.toxicity_labels) + + # Set use_mixed_precision flag + self.use_mixed_precision = self.mixed_precision != "no" + + def validate_model_config(self, model): + """Validate configuration against model architecture""" + try: + # Validate layer freezing + if self.freeze_layers > 0: + total_layers = len(list(model.base_model.encoder.layer)) + if self.freeze_layers > total_layers: + raise ValueError(f"Can't freeze {self.freeze_layers} layers in {total_layers}-layer model") + logger.info(f"Freezing {self.freeze_layers} out of {total_layers} layers") + + # Validate parameter groups and weight decay + param_groups = self.get_param_groups(model) + if self.weight_decay > 0: + low_lr_groups = [g for g in param_groups if g['lr'] < 0.01] + if low_lr_groups: + logger.warning("Found parameter groups with low learning rates (< 0.01) and non-zero weight decay:") + for group in low_lr_groups: + logger.warning(f"Group with lr={group['lr']:.4f}") + + return True + except Exception as e: + logger.error(f"Model configuration validation failed: {str(e)}") + raise + + @property + def dtype(self) -> torch.dtype: + """Get the appropriate dtype based on mixed precision settings""" + if self.mixed_precision == "bf16": + return torch.bfloat16 + elif self.mixed_precision == "fp16": + return torch.float16 + return torch.float32 + + def get_autocast_context(self): + """Get the appropriate autocast context based on configuration.""" + if not self.use_mixed_precision: + return nullcontext() + dtype = torch.bfloat16 if self.mixed_precision == "bf16" else torch.float16 + return torch.autocast(device_type=self.device.type, dtype=dtype) + + def to_serializable_dict(self): + """Convert config to a dictionary for saving.""" + config_dict = asdict(self) + return config_dict + + def get_param_groups(self, model): + """Get parameter groups with base learning rate""" + return [{'params': model.parameters(), 'lr': self.lr}] + + @property + def use_amp(self): + """Check if AMP should be used based on device and mixed precision setting""" + return self.device.type == 'cuda' and self.mixed_precision != "no" + + @property + def grad_norm_clip(self): + """Adaptive gradient clipping based on precision""" + if self.mixed_precision == "bf16": + return 1.5 # BF16 can handle slightly higher gradients than FP16 + if self.mixed_precision == "fp16": + return 1.0 # Most conservative for FP16 due to lower precision + return 5.0 # Full precision can handle larger gradients + + @property + def num_workers(self): + """Dynamically adjust workers based on system resources""" + if self._num_workers is None: + cpu_count = os.cpu_count() + if cpu_count is None: + 
self._num_workers = 0 + else: + # Leave at least 2 CPUs free, max 4 workers + self._num_workers = min(4, max(0, cpu_count - 2)) + logger.info(f"Dynamically set num_workers to {self._num_workers} (CPU count: {cpu_count})") + return self._num_workers + + @num_workers.setter + def num_workers(self, value): + """Allow manual override of num_workers""" + self._num_workers = value + logger.info(f"Manually set num_workers to {value}") \ No newline at end of file diff --git a/nohup.out b/nohup.out new file mode 100644 index 0000000000000000000000000000000000000000..581ba85fb8479a109e725034ea15daa814903557 --- /dev/null +++ b/nohup.out @@ -0,0 +1,938 @@ +Starting training with configuration: +====================================== +Error log: logs/error_20250401_104945.log +PYTHONPATH: :/home/deeptanshul/Toxic-Comment-Classification-using-Deep-Learning +====================================== +Starting training in background... +Training process started with PID: 7731 + +Monitor commands: +1. View error log: tail -f logs/error_20250401_104945.log +2. Check process status: ps -p 7731 +3. Stop training: kill 7731 +Warning: TF32 not supported on this GPU. Disabling. +Error during training: Missing label columns in dataset: ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + +Performing cleanup... +Fatal error: Missing label columns in dataset: ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + +Performing cleanup... + +Performing cleanup... +Warning: TF32 not supported on this GPU. Disabling. +Initialized dataset with 285264 samples +Loading sample 147000 +Loading sample 225000 +Loading sample 4000 +Loading sample 86000 +Loading sample 50000 +Loading sample 144000 +Loading sample 42000 +Loading sample 244000 +Loading sample 229000 +Loading sample 210000 +Loading sample 116000 +Loading sample 278000 +Loading sample 154000 +Loading sample 227000 +Loading sample 272000 +Loading sample 224000 +Loading sample 237000 +Loading sample 77000 +Loading sample 134000 +Loading sample 201000 +Loading sample 211000 +Loading sample 65000 +Loading sample 231000 +Loading sample 194000 +Loading sample 200000 +Loading sample 153000 +Loading sample 211000 +Loading sample 195000 +Loading sample 59000 +Loading sample 134000 +Loading sample 5000 +Loading sample 264000 +Loading sample 9000 +Loading sample 273000 +Loading sample 114000 +Loading sample 20000 +Loading sample 240000 +Loading sample 39000 +Loading sample 195000 +Loading sample 263000 +Loading sample 265000 +Loading sample 119000 +Loading sample 15000 +Loading sample 30000 +Loading sample 141000 +Loading sample 28000 +Loading sample 94000 +Loading sample 157000 +Loading sample 185000 +Loading sample 227000 +Loading sample 132000 +Loading sample 152000 +Loading sample 15000 +Loading sample 192000 +Loading sample 211000 +Loading sample 173000 +Loading sample 67000 +Loading sample 200000 +Loading sample 52000 +Loading sample 280000 +Loading sample 0 +Loading sample 157000 +Loading sample 72000 +Loading sample 278000 +Loading sample 198000 +Loading sample 179000 +Loading sample 27000 +Loading sample 33000 +Loading sample 221000 +Loading sample 231000 +Loading sample 144000 +Loading sample 235000 +Loading sample 42000 +Loading sample 155000 +Loading sample 155000 +Loading sample 8000 +Loading sample 201000 +Loading sample 191000 +Loading sample 151000 +Loading sample 71000 +Loading sample 218000 +Loading sample 283000 +Loading sample 171000 +Loading sample 47000 +Loading sample 57000 +Loading sample 244000 +Loading sample 
245000 [... long run of repeated "Loading sample N" progress lines omitted; only the tail of the log is kept below ...] +Loading
sample 156000 +Loading sample 76000 +Loading sample 194000 +Loading sample 88000 +Loading sample 153000 +Loading sample 149000 +Loading sample 155000 +Loading sample 269000 +Loading sample 100000 +Loading sample 33000 +Loading sample 31000 +Loading sample 5000 +Loading sample 109000 +Loading sample 273000 +Loading sample 3000 +Loading sample 223000 +Loading sample 71000 +Loading sample 231000 +Loading sample 234000 +Loading sample 207000 +Loading sample 90000 +Loading sample 42000 +Loading sample 194000 +Loading sample 116000 +Loading sample 170000 +Loading sample 122000 +Loading sample 166000 +Loading sample 219000 +Loading sample 22000 +Loading sample 227000 +Loading sample 45000 +Loading sample 141000 + +Received signal 15. Cleaning up... + +Performing cleanup... + +Performing cleanup... +Warning: Error during cleanup: name '_model' is not defined + +Performing cleanup... +Warning: Error during cleanup: name '_model' is not defined diff --git a/readme.md b/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..a71213a8ad1ac39672e20419061666aaffe131f3 --- /dev/null +++ b/readme.md @@ -0,0 +1,217 @@ +# Toxic Comment Classification using Deep Learning + +A multilingual toxic comment classification system using language-aware transformers and advanced deep learning techniques. + +## 🏗️ Architecture Overview + +### Core Components + +1. **LanguageAwareTransformer** + - Base: XLM-RoBERTa Large + - Custom language-aware attention mechanism + - Gating mechanism for feature fusion + - Language-specific dropout rates + - Support for 7 languages with English fallback + +2. **ToxicDataset** + - Efficient caching system + - Language ID mapping + - Memory pinning for CUDA optimization + - Automatic handling of missing values + +3. **Training System** + - Mixed precision training (BF16/FP16) + - Gradient accumulation + - Language-aware loss weighting + - Distributed training support + - Automatic threshold optimization + +### Key Features + +- **Language Awareness** + - Language-specific embeddings + - Dynamic dropout rates per language + - Language-aware attention mechanism + - Automatic fallback to English for unsupported languages + +- **Performance Optimization** + - Gradient checkpointing + - Memory-efficient attention + - Automatic mixed precision + - Caching system for processed data + - CUDA optimization with memory pinning + +- **Training Features** + - Weighted focal loss with language awareness + - Dynamic threshold optimization + - Early stopping with patience + - Gradient flow monitoring + - Comprehensive metric tracking + +## 📊 Data Processing + +### Input Format +```python +{ + 'comment_text': str, # The text to classify + 'lang': str, # Language code (en, ru, tr, es, fr, it, pt) + 'toxic': int, # Binary labels for each category + 'severe_toxic': int, + 'obscene': int, + 'threat': int, + 'insult': int, + 'identity_hate': int +} +``` + +### Language Support +- Primary: en, ru, tr, es, fr, it, pt +- Default fallback: en (English) +- Language ID mapping: {en: 0, ru: 1, tr: 2, es: 3, fr: 4, it: 5, pt: 6} + +## 🚀 Model Architecture + +### Base Model +- XLM-RoBERTa Large +- Hidden size: 1024 +- Attention heads: 16 +- Max sequence length: 128 + +### Custom Components + +1. **Language-Aware Classifier** +```python +- Input: Hidden states [batch_size, hidden_size] +- Language embeddings: [batch_size, 64] +- Projection: hidden_size + 64 -> 512 +- Output: 6 toxicity predictions +``` + +2. 
**Language-Aware Attention** +```python +- Input: Hidden states + Language embeddings +- Scaled dot product attention +- Gating mechanism for feature fusion +- Memory-efficient implementation +``` + +## 🛠️ Training Configuration + +### Hyperparameters +```python +{ + "batch_size": 32, + "grad_accum_steps": 2, + "epochs": 4, + "lr": 2e-5, + "weight_decay": 0.01, + "warmup_ratio": 0.1, + "label_smoothing": 0.01, + "model_dropout": 0.1, + "freeze_layers": 2 +} +``` + +### Optimization +- Optimizer: AdamW +- Learning rate scheduler: Cosine with warmup +- Mixed precision: BF16/FP16 +- Gradient clipping: 1.0 +- Gradient accumulation steps: 2 + +## 📈 Metrics and Monitoring + +### Training Metrics +- Loss (per language) +- AUC-ROC (macro) +- Precision, Recall, F1 +- Language-specific metrics +- Gradient norms +- Memory usage + +### Validation Metrics +- AUC-ROC (per class and language) +- Optimal thresholds per language +- Critical class performance (threat, identity_hate) +- Distribution shift monitoring + +## 🔧 Usage + +### Training +```bash +python model/train.py +``` + +### Inference +```python +from model.predict import predict_toxicity + +results = predict_toxicity( + text="Your text here", + model=model, + tokenizer=tokenizer, + config=config +) +``` + +## 🔍 Code Structure + +``` +model/ +├── language_aware_transformer.py # Core model architecture +├── train.py # Training loop and utilities +├── predict.py # Inference utilities +├── evaluation/ +│ ├── evaluate.py # Evaluation functions +│ └── threshold_optimizer.py # Dynamic threshold optimization +├── data/ +│ └── sampler.py # Custom sampling strategies +└── training_config.py # Configuration management +``` + +## 🤖 AI/ML Specific Notes + +1. **Tensor Shapes** + - Input IDs: [batch_size, seq_len] + - Attention Mask: [batch_size, seq_len] + - Language IDs: [batch_size] + - Hidden States: [batch_size, seq_len, hidden_size] + - Language Embeddings: [batch_size, embed_dim] + +2. **Critical Components** + - Language ID handling in forward pass + - Attention mask shape management + - Memory-efficient attention implementation + - Gradient flow in language-aware components + +3. **Performance Considerations** + - Cache management for processed data + - Memory pinning for GPU transfers + - Gradient accumulation for large batches + - Language-specific dropout rates + +4. **Error Handling** + - Language ID validation + - Shape compatibility checks + - Gradient norm monitoring + - Device placement verification + +## 📝 Notes for AI Systems + +1. When modifying the code: + - Maintain language ID handling in forward pass + - Preserve attention mask shape management + - Keep device consistency checks + - Handle BatchEncoding security in PyTorch 2.6+ + +2. Key attention points: + - Language ID tensor shape and type + - Attention mask broadcasting + - Memory-efficient attention implementation + - Gradient flow through language-aware components + +3. 
Common pitfalls: + - Incorrect attention mask shapes + - Language ID type mismatches + - Memory leaks in caching + - Device inconsistencies diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..670c1571b62974bfcd02fde864e2e3e1771a05d0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,260 @@ +about-time==4.2.1 +absl-py==2.1.0 +accelerate==1.3.0 +affinegap==1.12 +aiofiles==23.2.1 +aiohappyeyeballs==2.4.4 +aiohttp==3.11.11 +aiosignal==1.3.2 +alive-progress==3.2.0 +altair==5.5.0 +annotated-types==0.7.0 +ansicon==1.89.0 +anyio==4.9.0 +astunparse==1.6.3 +async-timeout==5.0.1 +attrs==25.1.0 +beautifulsoup4==4.12.3 +bitsandbytes==0.45.1 +blessed==1.20.0 +blinker==1.9.0 +blis==1.2.0 +BTrees==6.1 +cachetools==5.5.1 +catalogue==2.0.10 +categorical-distance==1.9 +certifi==2025.1.31 +cffi==1.17.1 +chardet==3.0.4 +charset-normalizer==3.4.1 +click==8.1.8 +cloudpathlib==0.20.0 +colorama==0.4.6 +coloredlogs==15.0.1 +confection==0.1.5 +contourpy==1.3.1 +cycler==0.12.1 +cymem==2.0.11 +datasets==3.2.0 +dedupe-Levenshtein-search==1.4.5 +dill==0.3.8 +distlib==0.3.9 +docker-pycreds==0.4.0 +DoubleMetaphone==1.1 +editor==1.6.6 +entrypoints==0.4 +exceptiongroup==1.2.2 +Faker==37.1.0 +fastapi==0.115.12 +favicon==0.7.0 +ffmpy==0.5.0 +filelock==3.17.0 +flatbuffers==25.1.24 +fonttools==4.55.8 +frozenlist==1.5.0 +fsspec==2024.9.0 +gast==0.6.0 +gitdb==4.0.12 +GitPython==3.1.44 +google-api-core==2.24.1 +google-auth==2.38.0 +google-cloud==0.34.0 +google-cloud-core==2.4.1 +google-cloud-translate==3.19.0 +google-pasta==0.2.0 +googleapis-common-protos==1.66.0 +GPUtil==1.4.0 +gradio==5.23.2 +gradio_client==1.8.0 +grapheme==0.6.0 +groovy==0.1.2 +grpc-google-iam-v1==0.14.0 +grpcio==1.70.0 +grpcio-status==1.70.0 +h11==0.14.0 +h2==3.2.0 +h5py==3.12.1 +haversine==2.9.0 +highered==0.2.1 +hpack==3.0.0 +hstspreload==2025.1.1 +htbuilder==0.9.0 +httpcore==1.0.7 +httpx==0.28.1 +huggingface-hub==0.28.1 +humanfriendly==10.0 +hyperframe==5.2.0 +idna==2.10 +imbalanced-learn==0.13.0 +inquirer==3.4.0 +iterative-stratification==0.1.9 +Jinja2==3.1.5 +jinxed==1.3.0 +joblib==1.4.2 +jsonschema==4.23.0 +jsonschema-specifications==2024.10.1 +keras==3.8.0 +kiwisolver==1.4.8 +langcodes==3.5.0 +langdetect==1.0.9 +langid==1.1.6 +language_data==1.3.0 +libclang==18.1.1 +lxml==5.3.1 +marisa-trie==1.2.1 +Markdown==3.7 +markdown-it-py==3.0.0 +markdownlit==0.0.7 +MarkupSafe==3.0.2 +matplotlib==3.10.0 +mdurl==0.1.2 +ml-dtypes==0.4.1 +mpmath==1.3.0 +multidict==6.1.0 +multiprocess==0.70.16 +murmurhash==1.0.12 +namex==0.0.8 +narwhals==1.33.0 +networkx==3.4.2 +nltk==3.9.1 +numpy==1.26.2 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.1.3 +nvidia-curand-cu12==10.3.5.147 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusparselt-cu12==0.6.2 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 +onnxruntime==1.21.0 +opt_einsum==3.4.0 +optree==0.14.0 +orjson==3.10.16 +packaging==24.2 +pandas==2.1.4 +peft==0.14.0 +persistent==6.1 +phonenumbers==8.13.54 +pillow==11.1.0 +platformdirs==4.3.6 +plotly==6.0.1 +preshed==3.0.9 +presidio_analyzer==2.2.357 +prometheus_client==0.21.1 +propcache==0.2.1 +proto-plus==1.26.0 +protobuf==5.29.3 +psutil==6.1.1 +pyarrow==15.0.0 +pyasn1==0.6.1 +pyasn1_modules==0.4.1 +pybind11==2.13.6 +pycparser==2.22 +pydantic==2.10.6 +pydantic_core==2.27.2 +pydeck==0.9.1 +pydub==0.25.1 
+Pygments==2.19.1 +pyhacrf-datamade==0.2.8 +PyLBFGS==0.2.0.16 +pymdown-extensions==10.14.3 +pyparsing==3.2.1 +python-dateutil==2.9.0.post0 +python-multipart==0.0.20 +pytz==2025.1 +pyuseragents==1.0.5 +PyYAML==6.0.2 +readchar==4.2.1 +referencing==0.36.2 +regex==2024.11.6 +requests==2.32.3 +requests-file==2.1.0 +rfc3986==1.5.0 +rich==13.9.4 +rpds-py==0.24.0 +rsa==4.9 +ruff==0.11.2 +runs==1.2.2 +safehttpx==0.1.6 +safeIO==1.2 +safetensors==0.5.2 +scikit-learn==1.6.1 +scipy==1.15.1 +seaborn==0.13.2 +semantic-version==2.10.0 +sentencepiece==0.2.0 +sentry-sdk==2.20.0 +setproctitle==1.3.4 +shellingham==1.5.4 +simplecosine==1.2 +six==1.17.0 +sklearn-compat==0.1.3 +smart-open==7.1.0 +smmap==5.0.2 +sniffio==1.3.1 +soupsieve==2.6 +spacy==3.8.4 +spacy-legacy==3.0.12 +spacy-loggers==1.0.5 +srsly==2.5.1 +st-annotated-text==4.0.2 +st-theme==1.2.3 +starlette==0.46.1 +streamlit==1.44.0 +streamlit-avatar==0.1.3 +streamlit-camera-input-live==0.2.0 +streamlit-card==1.0.2 +streamlit-embedcode==0.1.2 +streamlit-extras==0.6.0 +streamlit-faker==0.0.3 +streamlit-image-coordinates==0.1.9 +streamlit-keyup==0.3.0 +streamlit-toggle-switch==1.0.2 +streamlit-vertical-slider==2.5.5 +sympy==1.13.1 +tenacity==9.0.0 +tensorboard==2.18.0 +tensorboard-data-server==0.7.2 +tensorflow==2.18.0 +tensorflow-io-gcs-filesystem==0.37.1 +termcolor==2.5.0 +thinc==8.3.4 +threadpoolctl==3.5.0 +tldextract==5.1.3 +tokenizers==0.21.0 +toml==0.10.2 +tomlkit==0.13.2 +torch==2.6.0 +tornado==6.4.2 +tqdm==4.67.1 +transformers==4.48.2 +translatepy==2.3 +triton==3.2.0 +TurkishStemmer==1.3 +typer==0.15.1 +typing_extensions==4.12.2 +tzdata==2025.1 +urllib3==2.3.0 +uvicorn==0.34.0 +validators==0.34.0 +virtualenv==20.30.0 +wandb==0.19.5 +wasabi==1.1.3 +watchdog==6.0.0 +wcwidth==0.2.13 +weasel==0.4.1 +websockets==15.0.1 +Werkzeug==3.1.3 +wrapt==1.17.2 +xmod==1.8.1 +xxhash==3.5.0 +yarl==1.18.3 +zope.deferredimport==5.0 +zope.index==7.0 +zope.interface==7.2 +zope.proxy==6.1 diff --git a/run_streamlit.sh b/run_streamlit.sh new file mode 100644 index 0000000000000000000000000000000000000000..852e44a9ebeabee961d64674ce2749d3a67af90a --- /dev/null +++ b/run_streamlit.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# Streamlit Launcher Script for Toxic Comment Classifier +# This script launches the Streamlit version of the application + +echo "🚀 Starting Toxic Comment Classifier - Streamlit Edition" +echo "📚 Loading model and dependencies..." + +# Check for Python and Streamlit +if ! command -v python3 &> /dev/null; then + echo "❌ Python 3 is not installed. Please install Python 3 to run this application." + exit 1 +fi + +if ! python3 -c "import streamlit" &> /dev/null; then + echo "⚠️ Streamlit not found. Attempting to install dependencies..." + pip install -r requirements.txt +fi + +# Set default environment variables if not already set +export ONNX_MODEL_PATH=${ONNX_MODEL_PATH:-"weights/toxic_classifier.onnx"} +export PYTORCH_MODEL_DIR=${PYTORCH_MODEL_DIR:-"weights/toxic_classifier_xlm-roberta-large"} + +# Set Streamlit environment variables to reduce errors +export STREAMLIT_SERVER_WATCH_ONLY_USER_CONTENT=true +export STREAMLIT_SERVER_HEADLESS=true + +# Suppress TensorFlow warnings +export TF_CPP_MIN_LOG_LEVEL=2 +export TF_ENABLE_ONEDNN_OPTS=0 + +# Run the Streamlit app with disabled hot-reload to avoid PyTorch class errors +echo "✅ Launching Streamlit application..." 
+streamlit run streamlit_app.py --server.port=8501 --server.address=0.0.0.0 --server.runOnSave=false "$@" \ No newline at end of file diff --git a/streamlit_app.py b/streamlit_app.py new file mode 100644 index 0000000000000000000000000000000000000000..637b69189a81b72c4fdfe596cb5af0d1e3673132 --- /dev/null +++ b/streamlit_app.py @@ -0,0 +1,1578 @@ +# Fix for torch.classes watchdog errors +import sys +class ModuleProtector: + def __init__(self, module_name): + self.module_name = module_name + self.original_module = sys.modules.get(module_name) + + def __enter__(self): + if self.module_name in sys.modules: + self.original_module = sys.modules[self.module_name] + sys.modules[self.module_name] = None + + def __exit__(self, *args): + if self.original_module is not None: + sys.modules[self.module_name] = self.original_module + +# Temporarily remove torch.classes from sys.modules to prevent Streamlit's file watcher from accessing it +with ModuleProtector('torch.classes'): + import streamlit as st + +# Set page configuration - MUST BE THE FIRST STREAMLIT COMMAND +st.set_page_config( + page_title="Multilingual Toxicity Analyzer", + page_icon="data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHdpZHRoPSIyNCIgaGVpZ2h0PSIyNCIgdmlld0JveD0iMCAwIDI0IDI0IiBmaWxsPSJub25lIiBzdHJva2U9ImN1cnJlbnRDb2xvciIgc3Ryb2tlLXdpZHRoPSIyIiBzdHJva2UtbGluZWNhcD0icm91bmQiIHN0cm9rZS1saW5lam9pbj0icm91bmQiIGNsYXNzPSJsdWNpZGUgbHVjaWRlLXNoaWVsZC1wbHVzLWljb24gbHVjaWRlLXNoaWVsZC1wbHVzIj48cGF0aCBkPSJNMjAgMTNjMCA1LTMuNSA3LjUtNy42NiA4Ljk1YTEgMSAwIDAgMS0uNjctLjAxQzcuNSAyMC41IDQgMTggNCAxM1Y2YTEgMSAwIDAgMSAxLTFjMiAwIDQuNS0xLjIgNi4yNC0yLjcyYTEuMTcgMS4xNyAwIDAgMSAxLjUyIDBDMTQuNTEgMy44MSAxNyA1IDE5IDVhMSAxIDAgMCAxIDEgMXoiLz48cGF0aCBkPSJNOSAxMmg2Ii8+PHBhdGggZD0iTTEyIDl2NiIvPjwvc3ZnPg==", + layout="wide", + initial_sidebar_state="expanded" +) + +# Now import all other dependencies +import torch +import os +import plotly.graph_objects as go +import pandas as pd +from model.inference_optimized import OptimizedToxicityClassifier +import langid +from typing import List, Dict +import time +import psutil +import platform +try: + import cpuinfo +except ImportError: + cpuinfo = None +from streamlit_extras.colored_header import colored_header +from streamlit_extras.add_vertical_space import add_vertical_space +from streamlit_extras.stylable_container import stylable_container +from streamlit_extras.card import card +from streamlit_extras.metric_cards import style_metric_cards + +# Configure paths +ONNX_MODEL_PATH = os.environ.get("ONNX_MODEL_PATH", "weights/toxic_classifier.onnx") +PYTORCH_MODEL_DIR = os.environ.get("PYTORCH_MODEL_DIR", "weights/toxic_classifier_xlm-roberta-large") +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + +# Get GPU info if available +def get_gpu_info(): + if DEVICE == "cuda": + try: + gpu_name = torch.cuda.get_device_name(0) + gpu_memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3 # Convert to GB + gpu_memory_allocated = torch.cuda.memory_allocated(0) / 1024**3 # Convert to GB + cuda_version = torch.version.cuda + + memory_info = f"{gpu_memory_allocated:.1f}/{gpu_memory_total:.1f} GB" + return f"{gpu_name} (CUDA {cuda_version}, Memory: {memory_info})" + except Exception as e: + return "CUDA device" + return "CPU" + +# Get CPU information +def get_cpu_info(): + try: + cpu_percent = psutil.cpu_percent(interval=0.1) + cpu_count = psutil.cpu_count(logical=True) + cpu_freq = psutil.cpu_freq() + + if cpu_freq: + freq_info = f"{cpu_freq.current/1000:.2f} GHz" + else: + 
freq_info = "Unknown" + + # Try multiple methods to get CPU model name + cpu_model = None + + # Method 1: Try reading from /proc/cpuinfo directly + try: + with open('/proc/cpuinfo', 'r') as f: + for line in f: + if 'model name' in line: + cpu_model = line.split(':', 1)[1].strip() + break + except: + pass + + # Method 2: If Method 1 fails, try using platform.processor() + if not cpu_model: + cpu_model = platform.processor() + + # Method 3: If still no result, try using platform.machine() + if not cpu_model or cpu_model == '': + cpu_model = platform.machine() + + # Method 4: Final fallback to using psutil + if not cpu_model or cpu_model == '': + try: + import cpuinfo + cpu_model = cpuinfo.get_cpu_info()['brand_raw'] + except: + pass + + # Clean up the model name + if cpu_model: + # Remove common unnecessary parts + replacements = [ + '(R)', '(TM)', '(r)', '(tm)', 'CPU', '@', ' ', 'Processor' + ] + for r in replacements: + cpu_model = cpu_model.replace(r, ' ') + # Clean up extra spaces + cpu_model = ' '.join(cpu_model.split()) + # Limit length + if len(cpu_model) > 40: + cpu_model = cpu_model[:37] + "..." + else: + cpu_model = "Unknown CPU" + + return { + "name": cpu_model, + "cores": cpu_count, + "freq": freq_info, + "usage": f"{cpu_percent:.1f}%" + } + except Exception as e: + return { + "name": "CPU", + "cores": "Unknown", + "freq": "Unknown", + "usage": "Unknown" + } + +# Get RAM information +def get_ram_info(): + try: + ram = psutil.virtual_memory() + ram_total = ram.total / (1024**3) # Convert to GB + ram_used = ram.used / (1024**3) # Convert to GB + ram_percent = ram.percent + + return { + "total": f"{ram_total:.1f} GB", + "used": f"{ram_used:.1f} GB", + "percent": f"{ram_percent:.1f}%" + } + except Exception as e: + return { + "total": "Unknown", + "used": "Unknown", + "percent": "Unknown" + } + +# Update system resource information +def update_system_resources(): + cpu_info = get_cpu_info() + ram_info = get_ram_info() + + return { + "cpu": cpu_info, + "ram": ram_info + } + +# Initialize system information +GPU_INFO = get_gpu_info() +SYSTEM_INFO = update_system_resources() + +# Add a function to update GPU memory info in real-time +def update_gpu_info(): + if DEVICE == "cuda": + try: + gpu_memory_allocated = torch.cuda.memory_allocated(0) / 1024**3 # Convert to GB + gpu_memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3 # Convert to GB + return f"{gpu_memory_allocated:.1f}/{gpu_memory_total:.1f} GB" + except: + return "N/A" + return "N/A" + +# Helper function to convert hex to rgba +def hex_to_rgba(hex_color, alpha=1.0): + hex_color = hex_color.lstrip('#') + r = int(hex_color[0:2], 16) + g = int(hex_color[2:4], 16) + b = int(hex_color[4:6], 16) + return f'rgba({r}, {g}, {b}, {alpha})' + +# Supported languages with emoji flags +SUPPORTED_LANGUAGES = { + 'en': {'name': 'English', 'flag': '🇺🇸'}, + 'ru': {'name': 'Russian', 'flag': '🇷🇺'}, + 'tr': {'name': 'Turkish', 'flag': '🇹🇷'}, + 'es': {'name': 'Spanish', 'flag': '🇪🇸'}, + 'fr': {'name': 'French', 'flag': '🇫🇷'}, + 'it': {'name': 'Italian', 'flag': '🇮🇹'}, + 'pt': {'name': 'Portuguese', 'flag': '🇵🇹'} +} + +# Language examples - expanded with multiple examples per language, categorized as toxic or non-toxic +LANGUAGE_EXAMPLES = { + 'en': { + 'toxic': [ + "You are such an idiot, nobody likes your stupid content.", + "Shut up you worthless piece of garbage. Everyone hates you.", + "This is the most pathetic thing I've ever seen. 
Only losers would think this is good.", + "Just kill yourself already, the world would be better without you." + ], + 'non_toxic': [ + "I disagree with your opinion, but I appreciate your perspective.", + "This content could use some improvement, but I see the effort you put into it.", + "While I don't personally enjoy this type of content, others might find it valuable.", + "Thank you for sharing your thoughts on this complex topic." + ] + }, + 'ru': { + 'toxic': [ + "Ты полный придурок, твой контент никому не нравится.", + "Заткнись, бесполезный кусок мусора. Все тебя ненавидят.", + "Это самая жалкая вещь, которую я когда-либо видел. Только неудачники думают, что это хорошо.", + "Почему бы тебе просто не исчезнуть нахрен? Никто не будет скучать по тебе." + ], + 'non_toxic': [ + "Я не согласен с вашим мнением, но уважаю вашу точку зрения.", + "Этот контент можно улучшить, но я вижу, сколько усилий вы в него вложили.", + "Хотя мне лично не нравится такой контент, другие могут найти его полезным.", + "Спасибо, что поделились своими мыслями на эту сложную тему." + ] + }, + 'tr': { + 'toxic': [ + "Sen tam bir aptalsın, kimse senin aptalca içeriğini beğenmiyor.", + "Kapa çeneni değersiz çöp parçası. Herkes senden nefret ediyor.", + "Bu gördüğüm en acıklı şey. Sadece lanet olası kaybedenler bunun iyi olduğunu düşünür.", + "Dünya sensiz daha iyi olurdu, kaybol git." + ], + 'non_toxic': [ + "Fikrinize katılmıyorum ama bakış açınızı takdir ediyorum.", + "Bu içerik biraz geliştirilebilir, ancak gösterdiğiniz çabayı görüyorum.", + "Şahsen bu tür içerikten hoşlanmasam da, başkaları bunu değerli bulabilir.", + "Bu karmaşık konu hakkındaki düşüncelerinizi paylaştığınız için teşekkür ederim." + ] + }, + 'es': { + 'toxic': [ + "Eres un completo idiota y nadie te quiere.", + "Cállate, pedazo de basura inútil. Todos te odian.", + "Esto es lo más patético que he visto nunca. Solo los perdedores pensarían que esto es bueno.", + "El mundo estaría mejor sin ti, deberías desaparecer, joder." + ], + 'non_toxic': [ + "No estoy de acuerdo con tu opinión, pero aprecio tu perspectiva.", + "Este contenido podría mejorarse, pero veo el esfuerzo que has puesto en él.", + "Aunque personalmente no disfruto este tipo de contenido, otros podrían encontrarlo valioso.", + "Gracias por compartir tus pensamientos sobre este tema tan complejo." + ] + }, + 'fr': { + 'toxic': [ + "Tu es tellement stupide, personne n'aime ton contenu minable.", + "Ferme-la, espèce de déchet inutile. Tout le monde te déteste.", + "C'est la chose la plus pathétique que j'ai jamais vue. Seuls les loosers penseraient que c'est bien.", + "Le monde serait meilleur sans toi, connard, va-t'en." + ], + 'non_toxic': [ + "Je ne suis pas d'accord avec ton opinion, mais j'apprécie ta perspective.", + "Ce contenu pourrait être amélioré, mais je vois l'effort que tu y as mis.", + "Bien que personnellement je n'apprécie pas ce type de contenu, d'autres pourraient le trouver précieux.", + "Merci d'avoir partagé tes réflexions sur ce sujet complexe." + ] + }, + 'it': { + 'toxic': [ + "Sei un tale idiota, a nessuno piace il tuo contenuto stupido.", + "Chiudi quella bocca, pezzo di spazzatura inutile. Tutti ti odiano.", + "Questa è la cosa più patetica che abbia mai visto. Solo i perdenti penserebbero che sia buona.", + "Il mondo sarebbe migliore senza di te, sparisci." 
+ ], + 'non_toxic': [ + "Non sono d'accordo con la tua opinione, ma apprezzo la tua prospettiva.", + "Questo contenuto potrebbe essere migliorato, ma vedo lo sforzo che ci hai messo.", + "Anche se personalmente non apprezzo questo tipo di contenuto, altri potrebbero trovarlo utile.", + "Grazie per aver condiviso i tuoi pensieri su questo argomento complesso." + ] + }, + 'pt': { + 'toxic': [ + "Você é um idiota completo, ninguém gosta do seu conteúdo estúpido.", + "Cale a boca, seu pedaço de lixo inútil. Todos te odeiam.", + "Isso é a coisa mais patética que eu já vi. Só perdedores pensariam que isso é bom.", + "O mundo seria melhor sem você, desapareça." + ], + 'non_toxic': [ + "Eu discordo da sua opinião, mas aprecio sua perspectiva.", + "Este conteúdo poderia ser melhorado, mas vejo o esforço que você colocou nele.", + "Embora eu pessoalmente não goste deste tipo de conteúdo, outros podem achá-lo valioso.", + "Obrigado por compartilhar seus pensamentos sobre este tema complexo." + ] + } +} + +# Theme colors - Light theme with black text +THEME = { + "primary": "#2D3142", + "background": "#FFFFFF", + "surface": "#FFFFFF", + "text": "#000000", # Changed to pure black for maximum contrast + "text_secondary": "#FFFFFF", # For text that needs to be white + "button": "#000000", # Dark black for buttons + "toxic": "#E53935", # Darker red for better contrast + "non_toxic": "#2E7D32", # Darker green for better contrast + "warning": "#F57C00", # Darker orange for better contrast + "info": "#1976D2", # Darker blue for better contrast + "sidebar_bg": "#FFFFFF", + "card_bg": "white", + "input_bg": "#F8F9FA" +} + +# Custom CSS for better styling +st.markdown(f""" + +""", unsafe_allow_html=True) + +# Custom CSS for metric labels - Add this near the top with the other CSS +st.markdown(f""" + +""", unsafe_allow_html=True) + +# Load model at app start +@st.cache_resource +def load_classifier(): + try: + if os.path.exists(ONNX_MODEL_PATH): + classifier = OptimizedToxicityClassifier(onnx_path=ONNX_MODEL_PATH, device=DEVICE) + st.session_state['model_type'] = 'Loaded' + return classifier + elif os.path.exists(PYTORCH_MODEL_DIR): + classifier = OptimizedToxicityClassifier(pytorch_path=PYTORCH_MODEL_DIR, device=DEVICE) + st.session_state['model_type'] = 'Loaded' + return classifier + else: + st.error(f"❌ No model found at {ONNX_MODEL_PATH} or {PYTORCH_MODEL_DIR}") + return None + except Exception as e: + st.error(f"Error loading model: {str(e)}") + import traceback + st.error(traceback.format_exc()) + return None + +def detect_language(text: str) -> str: + """Detect language of input text""" + try: + lang, _ = langid.classify(text) + return lang if lang in SUPPORTED_LANGUAGES else 'en' + except: + return 'en' + +def predict_toxicity(text: str, selected_language: str = "Auto-detect") -> Dict: + """Predict toxicity of input text""" + if not text or not text.strip(): + return { + "error": "Please enter some text to analyze.", + "results": None + } + + if not st.session_state.get('model_loaded', False): + return { + "error": "Model not loaded. 
Please check logs.", + "results": None + } + + # Add a spinner while processing + with st.spinner("Analyzing text..."): + # Record start time for inference metrics + start_time = time.time() + + # Detect language if auto-detect is selected + if selected_language == "Auto-detect": + lang_detection_start = time.time() + lang_code = detect_language(text) + lang_detection_time = time.time() - lang_detection_start + detected = True + else: + # Get language code from the display name without flag + selected_name = selected_language.split(' ')[1] if len(selected_language.split(' ')) > 1 else selected_language + lang_code = next((code for code, info in SUPPORTED_LANGUAGES.items() + if info['name'] == selected_name), 'en') + lang_detection_time = 0 + detected = False + + # Run prediction + try: + model_inference_start = time.time() + results = classifier.predict([text], langs=[lang_code])[0] + model_inference_time = time.time() - model_inference_start + total_time = time.time() - start_time + + return { + "results": results, + "detected": detected, + "lang_code": lang_code, + "performance": { + "total_time": total_time, + "lang_detection_time": lang_detection_time, + "model_inference_time": model_inference_time + } + } + except Exception as e: + import traceback + traceback.print_exc() + return { + "error": f"Error processing text: {str(e)}", + "results": None + } + +# Function to set example text +def set_example(lang_code, example_type, example_index=0): + st.session_state['use_example'] = True + # Get the example based on the language, type and index + example = LANGUAGE_EXAMPLES[lang_code][example_type][example_index] + st.session_state['example_text'] = example + st.session_state['detected_lang'] = lang_code + st.session_state['example_info'] = { + 'type': example_type, + 'lang': lang_code, + 'index': example_index + } + +# Initialize session state for example selection if not present +if 'use_example' not in st.session_state: + st.session_state['use_example'] = False + st.session_state['example_text'] = "" + st.session_state['detected_lang'] = "Auto-detect" + st.session_state['example_info'] = None + +# Sidebar content +with st.sidebar: + st.markdown("

Multilingual Toxicity Analyzer

", unsafe_allow_html=True) + + st.markdown(""" + #### This app analyzes text for different types of toxicity across multiple languages with high accuracy. + """) + + # Create language cards with flags + st.markdown("#### Supported Languages:") + lang_cols = st.columns(2) + + for i, (code, info) in enumerate(SUPPORTED_LANGUAGES.items()): + col_idx = i % 2 + with lang_cols[col_idx]: + st.markdown(f"
{info['flag']} {info['name']}
", + unsafe_allow_html=True) + + st.divider() + + # Language selection dropdown moved to sidebar + st.markdown("### 🌐 Select Language") + language_options = ["Auto-detect"] + [f"{info['flag']} {info['name']}" for code, info in SUPPORTED_LANGUAGES.items()] + selected_language = st.selectbox( + "Choose language or use auto-detect", + language_options, + index=0, + key="selected_language", + help="Choose a specific language or use auto-detection" + ) + + # Examples moved to sidebar + st.markdown("### 📝 Try with examples:") + + # Create tabs for toxic and non-toxic examples + example_tabs = st.tabs(["Toxic Examples", "Non-Toxic Examples"]) + + # Order languages by putting the most common ones first + ordered_langs = ['en', 'es', 'fr', 'pt', 'it', 'ru', 'tr'] + + # Toxic examples tab + with example_tabs[0]: + st.markdown('
', unsafe_allow_html=True) + for lang_code in ordered_langs: + info = SUPPORTED_LANGUAGES[lang_code] + with st.expander(f"{info['flag']} {info['name']} examples"): + for i, example in enumerate(LANGUAGE_EXAMPLES[lang_code]['toxic']): + # Display a preview of the example + preview = example[:40] + "..." if len(example) > 40 else example + button_key = f"toxic_{lang_code}_{i}" + button_help = f"Try with this {info['name']} toxic example" + + # We can't directly apply CSS classes to Streamlit buttons, but we can wrap them + if st.button(f"Example {i+1}: {preview}", + key=button_key, + use_container_width=True, + help=button_help): + set_example(lang_code, 'toxic', i) + st.markdown('
', unsafe_allow_html=True) + + # Non-toxic examples tab + with example_tabs[1]: + st.markdown('
', unsafe_allow_html=True) + for lang_code in ordered_langs: + info = SUPPORTED_LANGUAGES[lang_code] + with st.expander(f"{info['flag']} {info['name']} examples"): + for i, example in enumerate(LANGUAGE_EXAMPLES[lang_code]['non_toxic']): + # Display a preview of the example + preview = example[:40] + "..." if len(example) > 40 else example + button_key = f"non_toxic_{lang_code}_{i}" + button_help = f"Try with this {info['name']} non-toxic example" + + if st.button(f"Example {i+1}: {preview}", + key=button_key, + use_container_width=True, + help=button_help): + set_example(lang_code, 'non_toxic', i) + st.markdown('
', unsafe_allow_html=True) + + st.divider() + + # Model and Hardware information in the sidebar with improved layout + st.markdown("### 💻 System Information", unsafe_allow_html=True) + + # Update system resources info + current_sys_info = update_system_resources() + + # GPU section + if DEVICE == "cuda": + st.markdown(""" +
+
🎮 GPU
+
+ """, unsafe_allow_html=True) + + gpu_name = GPU_INFO.split(" (")[0] + st.markdown(f"
Model: {gpu_name}
", unsafe_allow_html=True) + + cuda_version = "Unknown" + if "CUDA" in GPU_INFO: + cuda_version = GPU_INFO.split("CUDA ")[1].split(",")[0] + st.markdown(f"
CUDA: {cuda_version}
", unsafe_allow_html=True) + + current_gpu_memory = update_gpu_info() + st.markdown(f"
Memory: {current_gpu_memory}
", unsafe_allow_html=True) + + st.markdown("
", unsafe_allow_html=True) + + # CPU section + st.markdown(""" +
+
⚙️ CPU
+
+ """, unsafe_allow_html=True) + + cpu_info = current_sys_info["cpu"] + st.markdown(f"
Model: {cpu_info['name']}
", unsafe_allow_html=True) + st.markdown(f"
Cores: {cpu_info['cores']}
", unsafe_allow_html=True) + st.markdown(f"
Frequency: {cpu_info['freq']}
", unsafe_allow_html=True) + st.markdown(f"
Usage: {cpu_info['usage']}
", unsafe_allow_html=True) + + st.markdown("
", unsafe_allow_html=True) + + # RAM section + st.markdown(""" +
+
🧠 RAM
+
+ """, unsafe_allow_html=True) + + ram_info = current_sys_info["ram"] + st.markdown(f"
Total: {ram_info['total']}
", unsafe_allow_html=True) + st.markdown(f"
Used: {ram_info['used']}
", unsafe_allow_html=True) + st.markdown(f"
Usage: {ram_info['percent']}
", unsafe_allow_html=True) + + st.markdown("
", unsafe_allow_html=True) + + st.divider() + + # Toxicity Thresholds - Moved from results section to sidebar + st.markdown("### ⚙️ Toxicity Thresholds") + st.markdown(""" +
+ The model uses per-category decision thresholds, tuned per language during evaluation, to determine whether a text is toxic: + + - **Toxic**: 60% + - **Severe Toxic**: 54% + - **Obscene**: 60% + - **Threat**: 48% + - **Insult**: 60% + - **Identity Hate**: 50% + + These increased thresholds reduce false positives but may miss borderline toxic content.
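As a minimal illustration (not the application's own code), the sketch below applies these per-category cut-offs to a dict of predicted probabilities in the 0-1 range; the helper name `detected_categories` and the sample probabilities are invented for the example, while the threshold values are the ones listed above.

```python
# Minimal sketch, not the app's own implementation: apply the per-category
# cut-offs listed above to a dict of predicted probabilities in [0, 1].
CATEGORY_THRESHOLDS = {
    "toxic": 0.60,
    "severe_toxic": 0.54,
    "obscene": 0.60,
    "threat": 0.48,
    "insult": 0.60,
    "identity_hate": 0.50,
}

def detected_categories(probabilities: dict) -> list:
    """Return every category whose probability meets or exceeds its threshold."""
    return [
        label
        for label, prob in probabilities.items()
        if prob >= CATEGORY_THRESHOLDS.get(label, 0.5)  # 0.5 as a generic fallback
    ]

# Only "threat" crosses its 0.48 cut-off in this example.
print(detected_categories({"toxic": 0.55, "threat": 0.51, "insult": 0.20}))
```

Keeping the cut-offs in a single mapping makes it easy to tune one category, such as the lower threat threshold, without touching the others.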
+ """, unsafe_allow_html=True) + +# Display model loading status +if 'model_loaded' not in st.session_state: + with st.spinner("🔄 Loading model..."): + classifier = load_classifier() + if classifier: + st.session_state['model_loaded'] = True + st.success(f"✅ Model loaded successfully on {GPU_INFO}") + else: + st.session_state['model_loaded'] = False + st.error("❌ Failed to load model. Please check logs.") +else: + # Model already loaded, just get it from cache + classifier = load_classifier() + +# Main app +st.markdown(""" +

+ + + + + + Multilingual Toxicity Analyzer +

+""", unsafe_allow_html=True) +st.markdown(""" +

Detect toxic content in multiple languages with state-of-the-art accuracy

+""", unsafe_allow_html=True) + +# Text input area with interactive styling +with stylable_container( + key="text_input_container", + css_styles=f""" + {{ + border-radius: 10px; + overflow: hidden; + transition: all 0.3s ease; + box-shadow: 0 2px 8px rgba(0,0,0,0.15); + background-color: {THEME["card_bg"]}; + padding: 10px; + margin-bottom: 15px; + }} + + textarea {{ + caret-color: black !important; + color: {THEME["text"]} !important; + }} + + /* Ensure the text input cursor is visible */ + .stTextArea textarea {{ + caret-color: black !important; + }} + """ +): + # Get the current example text if it exists + current_example = st.session_state.get('example_text', '') + + # Set the text input value, allowing for modifications + text_input = st.text_area( + "Enter text to analyze", + height=80, + value=current_example if st.session_state.get('use_example', False) else st.session_state.get('text_input', ''), + key="text_input", + help="Enter text in any supported language to analyze for toxicity" + ) + + # Check if the text has been modified from the example + if st.session_state.get('use_example', False) and text_input != current_example: + # Text was modified, clear example state + st.session_state['use_example'] = False + st.session_state['example_text'] = "" + st.session_state['example_info'] = None + +# Analyze button with improved styling in a more compact layout +col1, col2, col3 = st.columns([1, 2, 1]) +with col2: + analyze_button = st.button( + "Analyze Text", + type="primary", + use_container_width=True, + help="Click to analyze the entered text for toxicity" + ) + +# Process when button is clicked or text is submitted +if analyze_button or (text_input and 'last_analyzed' not in st.session_state or st.session_state.get('last_analyzed') != text_input): + if text_input: + st.session_state['last_analyzed'] = text_input + + # Get system resource info before prediction + pre_prediction_resources = update_system_resources() + + # Make prediction + prediction = predict_toxicity(text_input, selected_language) + + # Update resource usage after prediction + post_prediction_resources = update_system_resources() + + # Calculate resource usage delta + resource_delta = { + "cpu_usage": float(post_prediction_resources["cpu"]["usage"].rstrip("%")) - float(pre_prediction_resources["cpu"]["usage"].rstrip("%")), + "ram_usage": float(post_prediction_resources["ram"]["percent"].rstrip("%")) - float(pre_prediction_resources["ram"]["percent"].rstrip("%")) + } + + # Update GPU memory info after prediction + if DEVICE == "cuda": + new_memory_info = update_gpu_info() + # Note: Ideally we would update the displayed memory usage here, + # but Streamlit doesn't support dynamic updates without a rerun, + # so we'll just include memory info in our metrics + + # Set analysis status flags but remove celebration effect code + st.session_state['is_analysis_complete'] = True + st.session_state['analysis_has_error'] = "error" in prediction and prediction["error"] + + if "error" in prediction and prediction["error"]: + st.error(prediction["error"]) + elif prediction["results"]: + # Remove celebration effect call + # celebration_effect() + + results = prediction["results"] + performance = prediction.get("performance", {}) + + # Overall toxicity result + is_toxic = results["is_toxic"] + result_color = THEME["toxic"] if is_toxic else THEME["non_toxic"] + result_text = "TOXIC" if is_toxic else "NON-TOXIC" + + # Language info + lang_code = prediction["lang_code"] + lang_info = SUPPORTED_LANGUAGES.get(lang_code, {"name": 
lang_code, "flag": "🌐"}) + + # Count toxic categories + toxic_count = len(results["toxic_categories"]) if is_toxic else 0 + + # Create data for visualization but don't display the table + categories = [] + probabilities = [] + statuses = [] + + # Use the same thresholds that are used in the inference model + category_thresholds = { + 'toxic': 0.60, + 'severe_toxic': 0.54, + 'obscene': 0.60, + 'threat': 0.48, + 'insult': 0.60, + 'identity_hate': 0.50 + } + + for label, prob in results["probabilities"].items(): + categories.append(label.replace('_', ' ').title()) + probabilities.append(round(prob * 100, 1)) + threshold = category_thresholds.get(label, 0.5) * 100 + statuses.append("DETECTED" if prob * 100 >= threshold else "Not Detected") + + # Sort by probability for the chart + chart_data = sorted(zip(categories, probabilities, statuses), key=lambda x: x[1], reverse=True) + chart_cats, chart_probs, chart_statuses = zip(*chart_data) + + # Two column layout for results + col1, col2 = st.columns([3, 2]) + + with col1: + # Card with overall result and detected categories + with stylable_container( + key="result_card", + css_styles=f""" + {{ + border-radius: 10px; + padding: 10px 15px; + background-color: {THEME["card_bg"]}; + border-left: 5px solid {result_color}; + margin-bottom: 10px; + box-shadow: 0 4px 12px rgba(0,0,0,0.1); + overflow: hidden; + }} + """ + ): + # Overall result with abbreviated display + st.markdown(f""" +
+

Analysis Result:

+ {result_text} +
+
+ Language: {lang_info['flag']} {lang_info['name']} {'(detected)' if prediction["detected"] else ''} +
+
+ Toxic Categories: {", ".join([f'{category.replace("_", " ").title()}' for category in results["toxic_categories"]]) if is_toxic and toxic_count > 0 else 'None'} +
+ """, unsafe_allow_html=True) + + # Add toxicity probability graph inside the result card + st.markdown("

Toxicity Probabilities:

", unsafe_allow_html=True) + + # Create a horizontal bar chart with Plotly + fig = go.Figure() + + # Add bars with different colors based on toxicity + for i, (cat, prob, status) in enumerate(zip(chart_cats, chart_probs, chart_statuses)): + color = THEME["toxic"] if status == "DETECTED" else THEME["non_toxic"] + border_color = hex_to_rgba(color, 0.85) # Using rgba for border + + fig.add_trace(go.Bar( + y=[cat], + x=[prob], + orientation='h', + name=cat, + marker=dict( + color=color, + line=dict( + color=border_color, + width=2 + ) + ), + text=[f"{prob}%"], + textposition='outside', + textfont=dict(size=16, weight='bold'), # Much larger, bold text + hoverinfo='text', + hovertext=[f"{cat}: {prob}%"] + )) + + # Update layout + fig.update_layout( + title=None, + xaxis_title="Probability (%)", + yaxis_title=None, # Remove y-axis title to save space + height=340, # Significantly increased height + margin=dict(l=10, r=40, t=20, b=40), # More margin space for labels + xaxis=dict( + range=[0, 115], # Extended for outside labels + gridcolor=hex_to_rgba(THEME["text"], 0.15), + zerolinecolor=hex_to_rgba(THEME["text"], 0.3), + color=THEME["text"], + tickfont=dict(size=15), # Larger tick font + title_font=dict(size=16, family="Space Grotesk, sans-serif") # Larger axis title + ), + yaxis=dict( + gridcolor=hex_to_rgba(THEME["text"], 0.15), + color=THEME["text"], + tickfont=dict(size=15, family="Space Grotesk, sans-serif", weight='bold'), # Larger, bold category names + automargin=True # Auto-adjust margin to fit category names + ), + bargap=0.3, # More space between bars + paper_bgcolor='rgba(0,0,0,0)', + plot_bgcolor='rgba(0,0,0,0)', + font=dict( + family="Space Grotesk, sans-serif", + color=THEME["text"], + size=15 # Larger base font size + ), + showlegend=False + ) + + # Grid lines + fig.update_xaxes( + showgrid=True, + gridwidth=1.5, # Slightly wider grid lines + gridcolor=hex_to_rgba(THEME["text"], 0.15), + dtick=20 + ) + + # Display the plot + st.plotly_chart(fig, use_container_width=True, config={ + 'displayModeBar': False, + 'displaylogo': False + }) + + with col2: + # Performance metrics card + if performance: + with stylable_container( + key="performance_metrics_card", + css_styles=f""" + {{ + border-radius: 10px; + padding: 20px; + background-color: {THEME["card_bg"]}; + border-left: 3px solid {THEME["primary"]}; + height: 100%; + box-shadow: 0 4px 12px rgba(0,0,0,0.1); + }} + """ + ): + st.markdown("

Performance Metrics

", unsafe_allow_html=True) + total_time = performance.get("total_time", 0) + inference_time = performance.get("model_inference_time", 0) + lang_detection_time = performance.get("lang_detection_time", 0) + + # Create tabs for different types of metrics + perf_tab1, perf_tab2 = st.tabs(["Time Metrics", "Resource Usage"]) + + with perf_tab1: + time_cols = st.columns(1) + with time_cols[0]: + # Use custom HTML metrics instead of st.metric + total_time_val = f"{total_time:.3f}s" + inference_time_val = f"{inference_time:.3f}s" + lang_detection_time_val = f"{lang_detection_time:.3f}s" + + st.markdown(f""" +
+
+ Total Time +
+
+ {total_time_val} +
+
+ +
+
+ Model Inference +
+
+ {inference_time_val} +
+
+ +
+
+ Language Detection +
+
+ {lang_detection_time_val} +
+
+ """, unsafe_allow_html=True) + + with perf_tab2: + # Display system resource metrics with custom HTML + current_sys_info = update_system_resources() + + # Format delta: add + sign for positive values + cpu_usage = current_sys_info["cpu"]["usage"] + cpu_delta = f"{resource_delta['cpu_usage']:+.1f}%" if abs(resource_delta['cpu_usage']) > 0.1 else None + cpu_delta_display = f" ({cpu_delta})" if cpu_delta else "" + + ram_usage = current_sys_info["ram"]["percent"] + ram_delta = f"{resource_delta['ram_usage']:+.1f}%" if abs(resource_delta['ram_usage']) > 0.1 else None + ram_delta_display = f" ({ram_delta})" if ram_delta else "" + + if DEVICE == "cuda": + gpu_memory = update_gpu_info() + memory_display = f"GPU Memory: {gpu_memory}" + else: + memory_display = f"System RAM: {current_sys_info['ram']['used']} / {current_sys_info['ram']['total']}" + + st.markdown(f""" +
+
+ CPU Usage +
+
+ {cpu_usage}{cpu_delta_display} +
+
+ +
+
+ RAM Usage +
+
+ {ram_usage}{ram_delta_display} +
+
+ +
+
+ Memory +
+
+ {memory_display} +
+
+ """, unsafe_allow_html=True) + else: + pass # Remove the info message + +# Bottom section with improved styling for usage guide +st.divider() +colored_header( + label="How to use this AI Model", + description="Follow these steps to analyze text for toxicity", + color_name="blue-70" +) + +# Steps with more engaging design +st.markdown(""" +
+
1
+
Enter text in the input box above. You can type directly or paste from another source.
+
+ +
+
2
+
Select a specific language from the sidebar or use the auto-detect feature if you're unsure.
+
+ +
+
3
+
Click "Analyze Text" to get detailed toxicity analysis results.
+
+ +
+
4
+
Examine the breakdown of toxicity categories, probabilities, and visualization.
+
+ +
+
5
+
Try different examples from the sidebar to see how the model performs with various languages.
+
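If you prefer to skip the UI, the short sketch below drives the same classifier directly from Python. It is only an illustration based on the calls streamlit_app.py itself makes (`OptimizedToxicityClassifier` and `predict(texts, langs=...)`); the weight paths are the application defaults and may differ in your setup.

```python
# Illustrative sketch based on the calls made in streamlit_app.py; the weight
# paths are the app's defaults and may need adjusting for your checkout.
from model.inference_optimized import OptimizedToxicityClassifier

classifier = OptimizedToxicityClassifier(
    onnx_path="weights/toxic_classifier.onnx",  # or pytorch_path="weights/toxic_classifier_xlm-roberta-large"
    device="cpu",  # "cuda" if a GPU is available
)

# predict() takes a list of texts plus their language codes and returns one
# result dict per input text.
result = classifier.predict(["Thank you for sharing your thoughts."], langs=["en"])[0]
print(result["is_toxic"], result["toxic_categories"])
print(result["probabilities"])
```

Unsupported language codes fall back to English inside the app (see `detect_language`), so `"en"` is a safe default when the language is unknown.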
+""", unsafe_allow_html=True) + +# Adding footer with credits and improved styling +st.markdown(""" + +""", unsafe_allow_html=True) \ No newline at end of file diff --git a/train.sh b/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..b3cdaf028f794135756af8474f24d424834ac328 --- /dev/null +++ b/train.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# Basic configuration +export CUDA_VISIBLE_DEVICES="0,1" +export PYTHONWARNINGS="ignore" +export PYTHONPATH="${PYTHONPATH}:${PWD}" # Add current directory to Python path + +# Create directories +mkdir -p logs weights cache + +# Get timestamp for error log only +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +ERROR_LOG="logs/error_${TIMESTAMP}.log" + +# Print configuration +echo "Starting training with configuration:" +echo "======================================" +echo "Error log: $ERROR_LOG" +echo "PYTHONPATH: $PYTHONPATH" +echo "======================================" + +# Start training with nohup, only redirecting stderr +echo "Starting training in background..." +nohup python model/train.py 2> "$ERROR_LOG" & + +# Save process ID +pid=$! +echo $pid > "logs/train_${TIMESTAMP}.pid" +echo "Training process started with PID: $pid" +echo +echo "Monitor commands:" +echo "1. View error log: tail -f $ERROR_LOG" +echo "2. Check process status: ps -p $pid" +echo "3. Stop training: kill $pid" \ No newline at end of file diff --git a/utils/KBin_labeling.py b/utils/KBin_labeling.py new file mode 100644 index 0000000000000000000000000000000000000000..2887b3035795fa017e0929f425e7023c7075a2ea --- /dev/null +++ b/utils/KBin_labeling.py @@ -0,0 +1,191 @@ +import pandas as pd +import numpy as np +from scipy import stats +from sklearn.preprocessing import KBinsDiscretizer +import matplotlib.pyplot as plt +import os + +class ToxicityOrdinalEncoder: + def __init__(self, n_bins=4, strategy='quantile'): + self.n_bins = n_bins + self.strategy = strategy + self.bin_edges = {} + self.ordinal_mapping = {} + self.label_mapping = {} + + def _get_optimal_bins(self, values): + """Dynamically determine bins using statistical analysis""" + unique_vals = np.unique(values) + if len(unique_vals) <= self.n_bins: + return sorted(unique_vals) + + # Handle 1D data properly and check sample size + if len(values) < 2: + return np.linspace(0, 1, self.n_bins + 1) + + try: + # Transpose for correct KDE dimensions (d, N) = (1, samples) + kde = stats.gaussian_kde(values.T) + x = np.linspace(0, 1, 100) + minima = [] + for i in range(1, len(x)-1): + if (kde(x[i]) < kde(x[i-1])) and (kde(x[i]) < kde(x[i+1])): + minima.append(x[i]) + + if minima: + return [0] + sorted(minima) + [1] + except np.linalg.LinAlgError: + pass + + # Fallback to KBinsDiscretizer + est = KBinsDiscretizer(n_bins=self.n_bins, + encode='ordinal', + strategy=self.strategy) + est.fit(values) + return est.bin_edges_[0] + + def fit(self, df, columns): + """Learn optimal binning for each toxicity category""" + for col in columns: + # Filter and validate non-zero values + non_zero = df[col][df[col] > 0].values.reshape(-1, 1) + + # Handle empty columns + if len(non_zero) == 0: + self.bin_edges[col] = [0, 1] + self.ordinal_mapping[col] = {0: 0} + continue + + # Handle small sample sizes + if len(non_zero) < 2: + self.bin_edges[col] = np.linspace(0, 1, self.n_bins + 1) + continue + + bins = self._get_optimal_bins(non_zero) + self.bin_edges[col] = bins + + # Create ordinal mapping + self.ordinal_mapping[col] = { + val: i for i, val in enumerate(sorted(np.unique(bins))) + } + + # Create label mapping for interpretability + 
self.label_mapping[col] = { + 0: 'Non-toxic', + 1: 'Low', + 2: 'Medium', + 3: 'High', + 4: 'Severe' + } + + return self + + def transform(self, df, columns): + """Apply learned ordinal mapping with safety checks""" + transformed = df.copy() + + for col in columns: + if col not in self.bin_edges: + raise ValueError(f"Column {col} not fitted") + + bins = self.bin_edges[col] + transformed[col] = pd.cut(df[col], bins=bins, + labels=False, include_lowest=True) + + # Preserve zero as separate class + transformed[col] = np.where(df[col] == 0, 0, transformed[col] + 1) + transformed[col] = transformed[col].astype(int) # Ensure integer type + + return transformed + +def plot_toxicity_distribution(df, transformed_df, column, bin_edges, save_dir='images'): + """Plot original vs binned distribution for a toxicity column""" + plt.figure(figsize=(15, 6)) + + # Original distribution + plt.subplot(1, 2, 1) + non_zero_vals = df[column][df[column] > 0] + if len(non_zero_vals) > 0: + plt.hist(non_zero_vals, bins=50, alpha=0.7) + plt.title(f'Original {column.replace("_", " ").title()} Distribution\n(Non-zero values)') + plt.xlabel('Toxicity Score') + plt.ylabel('Count') + + # Add bin edges as vertical lines + for edge in bin_edges[column]: + plt.axvline(x=edge, color='r', linestyle='--', alpha=0.5) + else: + plt.text(0.5, 0.5, 'No non-zero values', ha='center', va='center') + + # Binned distribution + plt.subplot(1, 2, 2) + unique_bins = sorted(transformed_df[column].unique()) + plt.hist(transformed_df[column], bins=len(unique_bins), + range=(min(unique_bins)-0.5, max(unique_bins)+0.5), + alpha=0.7, rwidth=0.8) + plt.title(f'Binned {column.replace("_", " ").title()} Distribution') + plt.xlabel('Toxicity Level') + plt.ylabel('Count') + + # Add labels for toxicity levels + plt.xticks(range(5), ['Non-toxic', 'Low', 'Medium', 'High', 'Severe']) + + plt.tight_layout() + os.makedirs(save_dir, exist_ok=True) + plt.savefig(os.path.join(save_dir, f'{column}_distribution.png')) + plt.close() + +def main(): + # Load dataset + print("Loading dataset...") + input_file = 'dataset/raw/MULTILINGUAL_TOXIC_DATASET_367k_7LANG_cleaned.csv' + df = pd.read_csv(input_file) + + # Define toxicity columns + toxicity_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + + # Print initial value distributions + print("\nInitial value distributions:") + for col in toxicity_cols: + print(f"\n{col.replace('_', ' ').title()}:") + print(df[col].value_counts().sort_index()) + + # Initialize and fit encoder + print("\nFitting toxicity encoder...") + encoder = ToxicityOrdinalEncoder(n_bins=4) + encoder.fit(df, toxicity_cols) + + # Transform data + print("Transforming toxicity values...") + transformed_df = encoder.transform(df, toxicity_cols) + + # Plot distributions + print("\nGenerating distribution plots...") + for col in toxicity_cols: + plot_toxicity_distribution(df, transformed_df, col, encoder.bin_edges) + + # Print binning information + print("\nBin edges for each toxicity type:") + for col in toxicity_cols: + print(f"\n{col.replace('_', ' ').title()}:") + edges = encoder.bin_edges[col] + for i in range(len(edges)-1): + print(f"Level {encoder.label_mapping[col][i+1]}: {edges[i]:.3f} to {edges[i+1]:.3f}") + + # Save transformed dataset + output_file = 'dataset/processed/MULTILINGUAL_TOXIC_DATASET_binned.csv' + print(f"\nSaving binned dataset to: {output_file}") + transformed_df.to_csv(output_file, index=False) + + # Print final value distributions + print("\nFinal binned distributions:") + for col in 
toxicity_cols: + print(f"\n{col.replace('_', ' ').title()}:") + dist = transformed_df[col].value_counts().sort_index() + for level, count in dist.items(): + print(f"{encoder.label_mapping[col][level]}: {count:,} ({count/len(df)*100:.1f}%)") + +if __name__ == "__main__": + main() + + diff --git a/utils/add_ids.py b/utils/add_ids.py new file mode 100644 index 0000000000000000000000000000000000000000..a28b1e3f20ee4646be0545438976aaba931f50cc --- /dev/null +++ b/utils/add_ids.py @@ -0,0 +1,78 @@ +import pandas as pd +import numpy as np +from pathlib import Path +import os +import hashlib + +def generate_comment_id(row, toxicity_cols): + """Generate a unique ID encoding language and toxicity information""" + # Get toxicity type codes + tox_code = ''.join(['1' if row[col] > 0 else '0' for col in toxicity_cols]) + + # Create a hash of the comment text for uniqueness + text_hash = hashlib.md5(row['comment_text'].encode()).hexdigest()[:6] + + # Combine language, toxicity code, and hash + # Format: {lang}_{toxicity_code}_{hash} + # Example: en_100010_a1b2c3 (English comment with toxic and insult flags) + return f"{row['lang']}_{tox_code}_{text_hash}" + +def add_dataset_ids(input_file, output_file=None): + """Add meaningful IDs to the dataset""" + print(f"\nReading dataset: {input_file}") + df = pd.read_csv(input_file) + + # Initial stats + total_rows = len(df) + print(f"\nInitial dataset size: {total_rows:,} comments") + + # Toxicity columns in order + toxicity_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + + print("\nGenerating IDs...") + # Generate IDs for each row + df['id'] = df.apply(lambda row: generate_comment_id(row, toxicity_cols), axis=1) + + # Verify ID uniqueness + unique_ids = df['id'].nunique() + print(f"\nGenerated {unique_ids:,} unique IDs") + + if unique_ids < total_rows: + print(f"Warning: {total_rows - unique_ids:,} duplicate IDs found") + # Handle duplicates by adding a suffix + df['id'] = df.groupby('id').cumcount().astype(str) + '_' + df['id'] + print("Added suffixes to make IDs unique") + + # Print sample IDs for each language + print("\nSample IDs by language:") + print("-" * 50) + for lang in df['lang'].unique(): + lang_sample = df[df['lang'] == lang].sample(n=min(3, len(df[df['lang'] == lang])), random_state=42) + print(f"\n{lang.upper()}:") + for _, row in lang_sample.iterrows(): + tox_types = [col for col in toxicity_cols if row[col] > 0] + print(f"ID: {row['id']}") + print(f"Toxicity: {', '.join(tox_types) if tox_types else 'None'}") + print(f"Text: {row['comment_text'][:100]}...") + + # Move ID column to first position + cols = ['id'] + [col for col in df.columns if col != 'id'] + df = df[cols] + + # Save dataset with IDs + if output_file is None: + base, ext = os.path.splitext(input_file) + output_file = f"{base}_with_ids{ext}" + + os.makedirs(os.path.dirname(output_file), exist_ok=True) + print(f"\nSaving dataset with IDs to: {output_file}") + df.to_csv(output_file, index=False) + print(f"File size: {Path(output_file).stat().st_size / (1024*1024):.1f} MB") + + return df + +if __name__ == "__main__": + input_file = "dataset/raw/MULTILINGUAL_TOXIC_DATASET_360K_7LANG_binary.csv" + output_file = "dataset/processed/MULTILINGUAL_TOXIC_DATASET_360K_7LANG_binary_with_ids.csv" + + df_with_ids = add_dataset_ids(input_file, output_file) \ No newline at end of file diff --git a/utils/balance_classes.py b/utils/balance_classes.py new file mode 100644 index 0000000000000000000000000000000000000000..2eaa08e619ed1f2930386e4c44ab18c84e3a76ee --- 
/dev/null +++ b/utils/balance_classes.py @@ -0,0 +1,159 @@ +import pandas as pd +import numpy as np +from pathlib import Path +import json +import os +from googletrans import Translator +from tqdm import tqdm +import time + +def get_class_stats(df, lang, column): + """Calculate statistics for a specific class and language""" + lang_df = df[df['lang'] == lang] + total = int(len(lang_df)) + positive_count = int(lang_df[column].sum()) + return { + 'total': total, + 'positive_count': positive_count, + 'positive_ratio': float(positive_count / total if total > 0 else 0) + } + +def backtranslate_text(text, translator, intermediate_lang='fr'): + """Backtranslate text using an intermediate language""" + try: + # Add delay to avoid rate limiting + time.sleep(1) + # Translate to intermediate language + intermediate = translator.translate(text, dest=intermediate_lang).text + # Translate back to English + time.sleep(1) + back_to_en = translator.translate(intermediate, dest='en').text + return back_to_en + except Exception as e: + print(f"Translation error: {str(e)}") + return text + +def balance_dataset_distributions(input_dir='dataset/balanced', output_dir='dataset/final_balanced'): + """Balance Turkish toxic class and augment English identity hate samples""" + print("\n=== Balancing Dataset Distributions ===\n") + + # Create output directory + Path(output_dir).mkdir(parents=True, exist_ok=True) + + # Load datasets + print("Loading datasets...") + train_df = pd.read_csv(os.path.join(input_dir, 'train_balanced.csv')) + val_df = pd.read_csv(os.path.join(input_dir, 'val_balanced.csv')) + test_df = pd.read_csv(os.path.join(input_dir, 'test_balanced.csv')) + + # 1. Fix Turkish Toxic Class Balance + print("\nInitial Turkish Toxic Distribution:") + print("-" * 50) + for name, df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]: + stats = get_class_stats(df, 'tr', 'toxic') + print(f"{name}: {stats['positive_count']}/{stats['total']} ({stats['positive_ratio']:.2%})") + + # Remove excess Turkish toxic samples from test + tr_test = test_df[test_df['lang'] == 'tr'] + target_ratio = get_class_stats(train_df, 'tr', 'toxic')['positive_ratio'] + current_ratio = get_class_stats(test_df, 'tr', 'toxic')['positive_ratio'] + + if current_ratio > target_ratio: + samples_to_remove = 150 # As specified + print(f"\nRemoving {samples_to_remove} Turkish toxic samples from test set...") + + # Identify and remove samples + np.random.seed(42) + tr_toxic_samples = test_df[ + (test_df['lang'] == 'tr') & + (test_df['toxic'] > 0) + ] + remove_idx = tr_toxic_samples.sample(n=samples_to_remove).index + test_df = test_df.drop(remove_idx) + + # 2. 
Augment English Identity Hate in Validation + print("\nInitial English Identity Hate Distribution:") + print("-" * 50) + for name, df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]: + stats = get_class_stats(df, 'en', 'identity_hate') + print(f"{name}: {stats['positive_count']}/{stats['total']} ({stats['positive_ratio']:.2%})") + + # Select samples for backtranslation + print("\nAugmenting English identity hate samples in validation set...") + en_train_hate = train_df[ + (train_df['lang'] == 'en') & + (train_df['identity_hate'] > 0) + ] + samples = en_train_hate.sample(n=50, replace=True, random_state=42) + + # Initialize translator + translator = Translator() + + # Perform backtranslation + print("Performing backtranslation (this may take a few minutes)...") + augmented_samples = [] + for _, row in tqdm(samples.iterrows(), total=len(samples)): + # Create new sample with backtranslated text + new_sample = row.copy() + new_sample['comment_text'] = backtranslate_text(row['comment_text'], translator) + augmented_samples.append(new_sample) + + # Add augmented samples to validation set + val_df = pd.concat([val_df, pd.DataFrame(augmented_samples)], ignore_index=True) + + # Save balanced datasets + print("\nSaving final balanced datasets...") + train_df.to_csv(os.path.join(output_dir, 'train_final.csv'), index=False) + val_df.to_csv(os.path.join(output_dir, 'val_final.csv'), index=False) + test_df.to_csv(os.path.join(output_dir, 'test_final.csv'), index=False) + + # Save balancing statistics + stats = { + 'turkish_toxic': { + 'original_distribution': { + 'train': get_class_stats(train_df, 'tr', 'toxic'), + 'val': get_class_stats(val_df, 'tr', 'toxic'), + 'test': get_class_stats(test_df, 'tr', 'toxic') + }, + 'samples_removed': 150 + }, + 'english_identity_hate': { + 'original_distribution': { + 'train': get_class_stats(train_df, 'en', 'identity_hate'), + 'val': get_class_stats(val_df, 'en', 'identity_hate'), + 'test': get_class_stats(test_df, 'en', 'identity_hate') + }, + 'samples_added': 50 + } + } + + with open(os.path.join(output_dir, 'balancing_stats.json'), 'w') as f: + json.dump(stats, f, indent=2) + + return train_df, val_df, test_df + +def validate_final_distributions(train_df, val_df, test_df): + """Validate the final distributions of all classes across languages""" + print("\nFinal Distribution Validation:") + print("-" * 50) + + classes = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + languages = sorted(train_df['lang'].unique()) + + for lang in languages: + print(f"\n{lang.upper()}:") + for class_name in classes: + print(f"\n {class_name.upper()}:") + for name, df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]: + stats = get_class_stats(df, lang, class_name) + print(f" {name}: {stats['positive_count']}/{stats['total']} ({stats['positive_ratio']:.2%})") + +if __name__ == "__main__": + # First install required package if not already installed + # !pip install googletrans==4.0.0-rc1 + + # Balance datasets + train_df, val_df, test_df = balance_dataset_distributions() + + # Validate final distributions + validate_final_distributions(train_df, val_df, test_df) \ No newline at end of file diff --git a/utils/calculate_weights.py b/utils/calculate_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..a62cbbe59bbfedd6e3be13b66ba833a9e0dbd14e --- /dev/null +++ b/utils/calculate_weights.py @@ -0,0 +1,129 @@ +import pandas as pd +import numpy as np +import json +from pathlib import Path +import os + +def 
calculate_class_weights(df, toxicity_cols): + """Calculate class weights using inverse frequency scaling""" + total_samples = len(df) + weights = {} + + # Calculate weights for each toxicity type + for col in toxicity_cols: + positive_count = (df[col] > 0).sum() + negative_count = total_samples - positive_count + + # Use balanced weights formula: n_samples / (n_classes * n_samples_for_class) + pos_weight = total_samples / (2 * positive_count) if positive_count > 0 else 0 + neg_weight = total_samples / (2 * negative_count) if negative_count > 0 else 0 + + weights[col] = { + 'positive_weight': pos_weight, + 'negative_weight': neg_weight, + 'positive_count': int(positive_count), + 'negative_count': int(negative_count), + 'positive_ratio': float(positive_count/total_samples), + 'negative_ratio': float(negative_count/total_samples) + } + + return weights + +def calculate_language_weights(df, toxicity_cols): + """Calculate class weights for each language""" + languages = df['lang'].unique() + language_weights = {} + + for lang in languages: + lang_df = df[df['lang'] == lang] + lang_weights = calculate_class_weights(lang_df, toxicity_cols) + language_weights[lang] = lang_weights + + return language_weights + +def normalize_weights(weights_dict, baseline_class='obscene'): + """Normalize weights relative to a baseline class""" + # Get the positive weight of baseline class + baseline_weight = None + for lang, lang_weights in weights_dict.items(): + if baseline_weight is None: + baseline_weight = lang_weights[baseline_class]['positive_weight'] + + normalized_weights = {} + for lang, lang_weights in weights_dict.items(): + normalized_weights[lang] = {} + for col, weights in lang_weights.items(): + normalized_weights[lang][col] = { + 'positive_weight': weights['positive_weight'] / baseline_weight, + 'negative_weight': weights['negative_weight'] / baseline_weight, + 'positive_count': weights['positive_count'], + 'negative_count': weights['negative_count'], + 'positive_ratio': weights['positive_ratio'], + 'negative_ratio': weights['negative_ratio'] + } + + return normalized_weights + +def generate_weights(input_file): + """Generate and save class weights for the dataset""" + print(f"\nReading dataset: {input_file}") + df = pd.read_csv(input_file) + + # Initial stats + total_rows = len(df) + print(f"\nTotal samples: {total_rows:,}") + + # Toxicity columns + toxicity_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + + # Calculate overall weights + print("\nCalculating overall weights...") + overall_weights = calculate_class_weights(df, toxicity_cols) + + # Calculate language-specific weights + print("\nCalculating language-specific weights...") + language_weights = calculate_language_weights(df, toxicity_cols) + + # Normalize weights + print("\nNormalizing weights...") + normalized_overall = normalize_weights({'overall': overall_weights})['overall'] + normalized_language = normalize_weights(language_weights) + + # Prepare weights dictionary + weights_dict = { + 'dataset_info': { + 'total_samples': total_rows, + 'n_languages': len(df['lang'].unique()), + 'languages': list(df['lang'].unique()) + }, + 'overall_weights': overall_weights, + 'normalized_overall_weights': normalized_overall, + 'language_weights': language_weights, + 'normalized_language_weights': normalized_language + } + + # Save weights + output_dir = "weights" + os.makedirs(output_dir, exist_ok=True) + output_file = os.path.join(output_dir, "class_weights.json") + + print(f"\nSaving weights to: {output_file}") 
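+    # The stored values follow the balanced formula n_samples / (2 * n_samples_for_class):
+    # e.g. (illustrative numbers) 360,000 rows with 36,000 'toxic' positives give a raw
+    # positive weight of 360000 / (2 * 36000) = 5.0, which normalize_weights() then divides
+    # by the 'obscene' baseline weight so the scales stay comparable across labels.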
+ with open(output_file, 'w') as f: + json.dump(weights_dict, f, indent=2) + + # Print summary + print("\nWeight Summary (Normalized Overall):") + print("-" * 50) + for col in toxicity_cols: + pos_weight = normalized_overall[col]['positive_weight'] + pos_count = normalized_overall[col]['positive_count'] + pos_ratio = normalized_overall[col]['positive_ratio'] + print(f"\n{col.replace('_', ' ').title()}:") + print(f" Positive samples: {pos_count:,} ({pos_ratio*100:.2f}%)") + print(f" Weight: {pos_weight:.2f}x") + + return weights_dict + +if __name__ == "__main__": + input_file = "dataset/processed/MULTILINGUAL_TOXIC_DATASET_360K_7LANG_FINAL.csv" + weights = generate_weights(input_file) \ No newline at end of file diff --git a/utils/check_dataset.py b/utils/check_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb7653a3c610acd0809f2f842d153dbd76c174c --- /dev/null +++ b/utils/check_dataset.py @@ -0,0 +1,40 @@ +import pandas as pd + +def check_dataset(): + try: + # Check train dataset + print("\nChecking train dataset...") + train_df = pd.read_csv("dataset/split/train.csv") + print("\nTrain Dataset Columns:") + print("-" * 50) + for col in train_df.columns: + print(f"- {col}") + print(f"\nTrain Dataset Shape: {train_df.shape}") + print("\nTrain Dataset Info:") + print(train_df.info()) + print("\nFirst few rows of train dataset:") + print(train_df.head()) + + # Check validation dataset + print("\nChecking validation dataset...") + val_df = pd.read_csv("dataset/split/val.csv") + print("\nValidation Dataset Columns:") + print("-" * 50) + for col in val_df.columns: + print(f"- {col}") + print(f"\nValidation Dataset Shape: {val_df.shape}") + + # Check test dataset + print("\nChecking test dataset...") + test_df = pd.read_csv("dataset/split/test.csv") + print("\nTest Dataset Columns:") + print("-" * 50) + for col in test_df.columns: + print(f"- {col}") + print(f"\nTest Dataset Shape: {test_df.shape}") + + except Exception as e: + print(f"Error: {str(e)}") + +if __name__ == "__main__": + check_dataset() \ No newline at end of file diff --git a/utils/clean_labels.py b/utils/clean_labels.py new file mode 100644 index 0000000000000000000000000000000000000000..7f454fb08fd3b23e1a59a595a9804b12953196c9 --- /dev/null +++ b/utils/clean_labels.py @@ -0,0 +1,73 @@ +import pandas as pd +import numpy as np +from pathlib import Path +import os + +def clean_toxicity_labels(input_file, output_file=None): + """Clean toxicity labels by converting fractional values to binary using ceiling""" + print(f"\nReading dataset: {input_file}") + df = pd.read_csv(input_file) + + # Initial stats + total_rows = len(df) + print(f"\nInitial dataset size: {total_rows:,} comments") + + # Toxicity columns to clean + toxicity_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + + # Print initial value distribution + print("\nInitial value distribution:") + print("-" * 50) + for col in toxicity_cols: + unique_vals = df[col].value_counts().sort_index() + print(f"\n{col.replace('_', ' ').title()}:") + for val, count in unique_vals.items(): + print(f" {val}: {count:,} comments") + + # Clean each toxicity column + print("\nCleaning labels...") + for col in toxicity_cols: + # Get unique values before cleaning + unique_before = df[col].nunique() + non_binary = df[~df[col].isin([0, 1])][col].unique() + + if len(non_binary) > 0: + print(f"\n{col.replace('_', ' ').title()}:") + print(f" Found {len(non_binary)} non-binary values: {sorted(non_binary)}") + + # Convert to binary using 
ceiling (any value > 0 becomes 1) + df[col] = np.ceil(df[col]).clip(0, 1).astype(int) + + # Print conversion results + unique_after = df[col].nunique() + print(f" Unique values before: {unique_before}") + print(f" Unique values after: {unique_after}") + + # Print final value distribution + print("\nFinal value distribution:") + print("-" * 50) + for col in toxicity_cols: + value_counts = df[col].value_counts().sort_index() + total = len(df) + print(f"\n{col.replace('_', ' ').title()}:") + for val, count in value_counts.items(): + percentage = (count / total) * 100 + print(f" {val}: {count:,} comments ({percentage:.2f}%)") + + # Save cleaned dataset + if output_file is None: + base, ext = os.path.splitext(input_file) + output_file = f"{base}_cleaned{ext}" + + os.makedirs(os.path.dirname(output_file), exist_ok=True) + print(f"\nSaving cleaned dataset to: {output_file}") + df.to_csv(output_file, index=False) + print(f"File size: {Path(output_file).stat().st_size / (1024*1024):.1f} MB") + + return df + +if __name__ == "__main__": + input_file = "dataset/raw/MULTILINGUAL_TOXIC_DATASET_360K_7LANG.csv" + output_file = "dataset/processed/MULTILINGUAL_TOXIC_DATASET_360K_7LANG_binary.csv" + + cleaned_df = clean_toxicity_labels(input_file, output_file) \ No newline at end of file diff --git a/utils/clean_text.py b/utils/clean_text.py new file mode 100644 index 0000000000000000000000000000000000000000..ce1f1428731c2f71ecdd31cbe045ff6df05ca77a --- /dev/null +++ b/utils/clean_text.py @@ -0,0 +1,116 @@ +import pandas as pd +import re +from bs4 import BeautifulSoup +from tqdm import tqdm +import logging +from pathlib import Path + +def clean_text(text): + """Clean text by removing URLs, HTML tags, and special characters""" + try: + # Convert to string if not already + text = str(text) + + # Remove URLs + text = re.sub(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', '', text) + + # Remove HTML tags + text = BeautifulSoup(text, "html.parser").get_text() + + # Remove special characters but keep basic punctuation + text = re.sub(r'[^\w\s.,!?-]', ' ', text) + + # Remove extra whitespace + text = ' '.join(text.split()) + + # Remove multiple punctuation + text = re.sub(r'([.,!?])\1+', r'\1', text) + + # Remove spaces before punctuation + text = re.sub(r'\s+([.,!?])', r'\1', text) + + return text.strip() + except Exception as e: + logging.error(f"Error cleaning text: {str(e)}") + return text + +def try_read_csv(file_path): + """Try different encodings to read the CSV file""" + encodings = ['utf-8', 'latin1', 'iso-8859-1', 'cp1252'] + + for encoding in encodings: + try: + print(f"Trying {encoding} encoding...") + return pd.read_csv(file_path, encoding=encoding) + except UnicodeDecodeError: + continue + except Exception as e: + print(f"Error with {encoding}: {str(e)}") + continue + + raise ValueError("Could not read file with any of the attempted encodings") + +def clean_dataset(input_path, output_path=None): + """Clean comment text in a dataset""" + print(f"\nReading input file: {input_path}") + + # If no output path specified, use input name with _cleaned suffix + if output_path is None: + output_path = str(Path(input_path).with_suffix('').with_name(f"{Path(input_path).stem}_cleaned.csv")) + + try: + # Try reading with different encodings + df = try_read_csv(input_path) + total_rows = len(df) + + print(f"\nDataset Info:") + print(f"Initial Rows: {total_rows:,}") + print(f"Columns: {', '.join(df.columns)}") + + # Verify 'comment_text' column exists + if 'comment_text' not in 
df.columns: + # Try to find a column that might contain the comments + text_columns = [col for col in df.columns if 'text' in col.lower() or 'comment' in col.lower()] + if text_columns: + print(f"\nUsing '{text_columns[0]}' as comment column") + df['comment_text'] = df[text_columns[0]] + else: + raise ValueError("Could not find comment text column") + + # Clean comment text with progress bar + print("\nCleaning comments...") + tqdm.pandas() + df['comment_text'] = df['comment_text'].progress_apply(clean_text) + + # Remove empty comments + non_empty_mask = df['comment_text'].str.strip().str.len() > 0 + df = df[non_empty_mask] + + # Save cleaned dataset + print(f"\nSaving to: {output_path}") + df.to_csv(output_path, index=False, encoding='utf-8') + + # Print statistics + print(f"\n✓ Successfully cleaned comments") + print(f"Initial rows: {total_rows:,}") + print(f"Final rows: {len(df):,}") + print(f"Removed empty rows: {total_rows - len(df):,}") + print(f"Output file: {output_path}") + print(f"Output file size: {Path(output_path).stat().st_size / (1024*1024):.1f} MB") + + # Sample of cleaned comments + print("\nSample of cleaned comments:") + for i, (orig, cleaned) in enumerate(zip(df['comment_text'].head(3), df['comment_text'].head(3))): + print(f"\nExample {i+1}:") + print(f"Original : {orig[:100]}...") + print(f"Cleaned : {cleaned[:100]}...") + + except Exception as e: + print(f"\n❌ Error: {str(e)}") + return + +if __name__ == "__main__": + input_path = "dataset/raw/english-trash.csv" + output_path = "dataset/raw/english-comments-cleaned.csv" + + clean_dataset(input_path, output_path) \ No newline at end of file diff --git a/utils/dataset_card.py b/utils/dataset_card.py new file mode 100644 index 0000000000000000000000000000000000000000..259ebe45cbf6479cf5de3e8b32cb8849bf983f3c --- /dev/null +++ b/utils/dataset_card.py @@ -0,0 +1,105 @@ +import pandas as pd +import os +from pathlib import Path +import json +from datetime import datetime + +def create_dataset_card(file_path): + """Create a dataset card with key information about the CSV file""" + try: + # Read the CSV file + df = pd.read_csv(file_path, encoding='utf-8') + + # Get file info + file_stats = os.stat(file_path) + file_size_mb = file_stats.st_size / (1024 * 1024) + last_modified = datetime.fromtimestamp(file_stats.st_mtime).strftime('%Y-%m-%d %H:%M:%S') + + # Create dataset card + card = { + "filename": Path(file_path).name, + "last_modified": last_modified, + "file_size_mb": round(file_size_mb, 2), + "num_rows": len(df), + "num_columns": len(df.columns), + "columns": list(df.columns), + "column_dtypes": df.dtypes.astype(str).to_dict(), + "null_counts": df.isnull().sum().to_dict(), + "sample_rows": df.head(3).to_dict('records') + } + + # Add language distribution if 'lang' column exists + if 'lang' in df.columns: + card["language_distribution"] = df['lang'].value_counts().to_dict() + + # Add label distribution if any toxic-related columns exist + toxic_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + label_stats = {} + for col in toxic_cols: + if col in df.columns: + label_stats[col] = df[col].value_counts().to_dict() + if label_stats: + card["label_distribution"] = label_stats + + return card + + except Exception as e: + return { + "filename": Path(file_path).name, + "error": str(e) + } + +def scan_dataset_directory(directory="dataset"): + """Scan directory for CSV files and create dataset cards""" + print(f"\nScanning directory: {directory}") + + # Find all CSV files + csv_files = [] + for 
root, _, files in os.walk(directory): + for file in files: + if file.endswith('.csv'): + csv_files.append(os.path.join(root, file)) + + if not csv_files: + print("No CSV files found!") + return + + print(f"\nFound {len(csv_files)} CSV files") + + # Create dataset cards + cards = {} + for file_path in csv_files: + print(f"\nProcessing: {file_path}") + cards[file_path] = create_dataset_card(file_path) + + # Save to JSON file + output_file = "dataset/dataset_cards.json" + with open(output_file, 'w', encoding='utf-8') as f: + json.dump(cards, f, indent=2, ensure_ascii=False) + + print(f"\n✓ Dataset cards saved to: {output_file}") + + # Print summary for each file + for file_path, card in cards.items(): + print(f"\n{'='*80}") + print(f"File: {card['filename']}") + if 'error' in card: + print(f"Error: {card['error']}") + continue + + print(f"Size: {card['file_size_mb']:.2f} MB") + print(f"Rows: {card['num_rows']:,}") + print(f"Columns: {', '.join(card['columns'])}") + + if 'language_distribution' in card: + print("\nLanguage Distribution:") + for lang, count in card['language_distribution'].items(): + print(f" {lang}: {count:,}") + + if 'label_distribution' in card: + print("\nLabel Distribution:") + for label, dist in card['label_distribution'].items(): + print(f" {label}: {dist}") + +if __name__ == "__main__": + scan_dataset_directory() \ No newline at end of file diff --git a/utils/extract_thresholds.py b/utils/extract_thresholds.py new file mode 100644 index 0000000000000000000000000000000000000000..670119f760afbccc0b6c1d27864d344e22702a3d --- /dev/null +++ b/utils/extract_thresholds.py @@ -0,0 +1,43 @@ +import json +import os +from pathlib import Path + +def extract_thresholds(eval_results_path: str, output_path: str = None) -> dict: + """ + Extract classification thresholds from evaluation results JSON file. + + Args: + eval_results_path (str): Path to the evaluation results JSON file + output_path (str, optional): Path to save the extracted thresholds. 
+ If None, will save in the same directory as eval results + + Returns: + dict: Dictionary containing the extracted thresholds per language + """ + # Read evaluation results + with open(eval_results_path, 'r') as f: + results = json.load(f) + + # Extract thresholds + thresholds = results.get('thresholds', {}) + + # Save to file if output path provided + if output_path is None: + # Create thresholds file in same directory as eval results + eval_dir = os.path.dirname(eval_results_path) + output_path = os.path.join(eval_dir, 'thresholds.json') + + # Ensure directory exists + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + # Save with nice formatting + with open(output_path, 'w') as f: + json.dump(thresholds, f, indent=2) + + return thresholds + +if __name__ == '__main__': + # Example usage + eval_results_path = 'evaluation_results/eval_20250208_161149/evaluation_results.json' + thresholds = extract_thresholds(eval_results_path) + print("Thresholds extracted and saved successfully!") \ No newline at end of file diff --git a/utils/filter_toxic.py b/utils/filter_toxic.py new file mode 100644 index 0000000000000000000000000000000000000000..a0922c5baab58d31443c265b961b2c227b9ebb76 --- /dev/null +++ b/utils/filter_toxic.py @@ -0,0 +1,120 @@ +import pandas as pd +import os +import numpy as np + +def filter_and_balance_comments(input_file, output_file=None): + """Filter and balance dataset by maximizing toxic comments and matching with non-toxic""" + print(f"\nReading dataset: {input_file}") + df = pd.read_csv(input_file) + + # Initial stats + total_rows = len(df) + print(f"\nInitial dataset size: {total_rows:,} comments") + + # Toxicity columns + toxicity_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + + # Print initial toxicity distribution + print("\nInitial toxicity distribution:") + for col in toxicity_cols: + toxic_count = (df[col] > 0).sum() + print(f"{col.replace('_', ' ').title()}: {toxic_count:,} ({toxic_count/total_rows*100:.1f}%)") + + # Create mask for any toxicity + toxic_mask = df[toxicity_cols].any(axis=1) + + # Process each language separately to maintain balance + languages = df['lang'].unique() if 'lang' in df.columns else ['en'] + balanced_dfs = [] + + print("\nProcessing each language:") + for lang in languages: + print(f"\n{lang}:") + # If no lang column, use entire dataset + if 'lang' in df.columns: + lang_df = df[df['lang'] == lang] + else: + lang_df = df + + # Split into toxic and non-toxic + lang_toxic_df = lang_df[toxic_mask] if 'lang' in df.columns else lang_df[toxic_mask] + lang_non_toxic_df = lang_df[~toxic_mask] if 'lang' in df.columns else lang_df[~toxic_mask] + + toxic_count = len(lang_toxic_df) + non_toxic_count = len(lang_non_toxic_df) + + print(f"Total comments: {len(lang_df):,}") + print(f"Toxic comments available: {toxic_count:,}") + print(f"Non-toxic comments available: {non_toxic_count:,}") + + # Keep all toxic comments + sampled_toxic = lang_toxic_df + print(f"Kept all {toxic_count:,} toxic comments") + + # Sample equal number of non-toxic comments + if non_toxic_count >= toxic_count: + sampled_non_toxic = lang_non_toxic_df.sample(n=toxic_count, random_state=42) + print(f"Sampled {toxic_count:,} non-toxic comments to match") + else: + # If we have fewer non-toxic than toxic, use all non-toxic and sample additional with replacement + sampled_non_toxic = lang_non_toxic_df + additional_needed = toxic_count - non_toxic_count + if additional_needed > 0: + additional_samples = 
lang_non_toxic_df.sample(n=additional_needed, replace=True, random_state=42) + sampled_non_toxic = pd.concat([sampled_non_toxic, additional_samples], ignore_index=True) + print(f"Using all {non_toxic_count:,} non-toxic comments and added {additional_needed:,} resampled to balance") + + # Combine toxic and non-toxic for this language + lang_balanced = pd.concat([sampled_toxic, sampled_non_toxic], ignore_index=True) + print(f"Final language size: {len(lang_balanced):,} ({len(sampled_toxic):,} toxic, {len(sampled_non_toxic):,} non-toxic)") + balanced_dfs.append(lang_balanced) + + # Combine all balanced dataframes + balanced_df = pd.concat(balanced_dfs, ignore_index=True) + + # If we have more than target size, sample down + target_size = 51518 # Target size from the original requirement + if len(balanced_df) > target_size: + balanced_df = balanced_df.sample(n=target_size, random_state=42) + print(f"\nSampled down to {target_size:,} comments") + else: + print(f"\nKept all {len(balanced_df):,} comments (less than target size {target_size:,})") + + # Get final statistics + print("\nFinal dataset statistics:") + print(f"Total comments: {len(balanced_df):,}") + + if 'lang' in balanced_df.columns: + print("\nLanguage distribution in final dataset:") + lang_dist = balanced_df['lang'].value_counts() + for lang, count in lang_dist.items(): + toxic_in_lang = balanced_df[balanced_df['lang'] == lang][toxicity_cols].any(axis=1).sum() + print(f"{lang}: {count:,} comments ({toxic_in_lang:,} toxic, {count-toxic_in_lang:,} non-toxic)") + + print("\nToxicity distribution in final dataset:") + for col in toxicity_cols: + toxic_count = (balanced_df[col] > 0).sum() + print(f"{col.replace('_', ' ').title()}: {toxic_count:,} ({toxic_count/len(balanced_df)*100:.1f}%)") + + # Count comments with multiple toxicity types + toxic_counts = balanced_df[toxicity_cols].astype(bool).sum(axis=1) + print("\nComments by number of toxicity types:") + for n_toxic, count in toxic_counts.value_counts().sort_index().items(): + print(f"{n_toxic} type{'s' if n_toxic != 1 else ''}: {count:,} ({count/len(balanced_df)*100:.1f}%)") + + # Save balanced dataset + if output_file is None: + base, ext = os.path.splitext(input_file) + output_file = f"{base}_balanced{ext}" + + print(f"\nSaving balanced dataset to: {output_file}") + balanced_df.to_csv(output_file, index=False) + print(f"File size: {os.path.getsize(output_file) / (1024*1024):.1f} MB") + + return balanced_df + +if __name__ == "__main__": + input_file = "dataset/processed/english_merged.csv" + output_file = "dataset/processed/english_filtered.csv" + + filtered_df = filter_and_balance_comments(input_file, output_file) \ No newline at end of file diff --git a/utils/fix_pt_threat.py b/utils/fix_pt_threat.py new file mode 100644 index 0000000000000000000000000000000000000000..914312c9010a4e5eba27967372a98e8262c800a0 --- /dev/null +++ b/utils/fix_pt_threat.py @@ -0,0 +1,121 @@ +import pandas as pd +import numpy as np +from pathlib import Path +import json +import os + +def get_threat_stats(df, lang='pt'): + """Calculate threat statistics for a given language""" + lang_df = df[df['lang'] == lang] + total = int(len(lang_df)) # Convert to native Python int + threat_count = int(lang_df['threat'].sum()) # Convert to native Python int + return { + 'total': total, + 'threat_count': threat_count, + 'threat_ratio': float(threat_count / total if total > 0 else 0) # Convert to native Python float + } + +def fix_pt_threat_distribution(input_dir='dataset/split', output_dir='dataset/balanced'): + 
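# Approach (as implemented below): the Portuguese threat ratio in train.csv is the target;
+    # if the test split sits above it, randomly chosen (seed 42) Portuguese threat rows are
+    # dropped until roughly current_threats - len(pt_test) * target_ratio samples are removed,
+    # and train/val are written out unchanged as *_balanced.csv.
+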
"""Fix Portuguese threat class overrepresentation while maintaining dataset balance""" + print("\n=== Fixing Portuguese Threat Distribution ===\n") + + # Create output directory + Path(output_dir).mkdir(parents=True, exist_ok=True) + + # Load datasets + print("Loading datasets...") + train_df = pd.read_csv(os.path.join(input_dir, 'train.csv')) + val_df = pd.read_csv(os.path.join(input_dir, 'val.csv')) + test_df = pd.read_csv(os.path.join(input_dir, 'test.csv')) + + print("\nInitial Portuguese Threat Distribution:") + print("-" * 50) + for name, df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]: + stats = get_threat_stats(df) + print(f"{name}: {stats['threat_count']}/{stats['total']} ({stats['threat_ratio']:.2%})") + + # Calculate target ratio based on train set + target_ratio = float(get_threat_stats(train_df)['threat_ratio']) # Convert to native Python float + print(f"\nTarget threat ratio (from train): {target_ratio:.2%}") + + # Fix test set distribution + pt_test = test_df[test_df['lang'] == 'pt'] + current_ratio = float(get_threat_stats(test_df)['threat_ratio']) # Convert to native Python float + + if current_ratio > target_ratio: + # Calculate how many samples to remove + current_threats = int(pt_test['threat'].sum()) # Convert to native Python int + target_threats = int(len(pt_test) * target_ratio) + samples_to_remove = int(current_threats - target_threats) + + print(f"\nRemoving {samples_to_remove} Portuguese threat samples from test set...") + + # Identify samples to remove + pt_threat_samples = test_df[ + (test_df['lang'] == 'pt') & + (test_df['threat'] > 0) + ] + + # Randomly select samples to remove + np.random.seed(42) # For reproducibility + remove_idx = np.random.choice( + pt_threat_samples.index, + size=samples_to_remove, + replace=False + ).tolist() # Convert to native Python list + + # Remove selected samples + test_df = test_df.drop(remove_idx) + + # Verify new distribution + new_ratio = float(get_threat_stats(test_df)['threat_ratio']) # Convert to native Python float + print(f"New Portuguese threat ratio: {new_ratio:.2%}") + + # Save statistics + stats = { + 'original_distribution': { + 'train': get_threat_stats(train_df), + 'val': get_threat_stats(val_df), + 'test': get_threat_stats(test_df) + }, + 'samples_removed': samples_to_remove, + 'target_ratio': target_ratio, + 'achieved_ratio': new_ratio + } + + with open(os.path.join(output_dir, 'pt_threat_fix_stats.json'), 'w') as f: + json.dump(stats, f, indent=2) + + # Save balanced datasets + print("\nSaving balanced datasets...") + train_df.to_csv(os.path.join(output_dir, 'train_balanced.csv'), index=False) + val_df.to_csv(os.path.join(output_dir, 'val_balanced.csv'), index=False) + test_df.to_csv(os.path.join(output_dir, 'test_balanced.csv'), index=False) + + print("\nFinal Portuguese Threat Distribution:") + print("-" * 50) + for name, df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]: + stats = get_threat_stats(df) + print(f"{name}: {stats['threat_count']}/{stats['total']} ({stats['threat_ratio']:.2%})") + else: + print("\nNo fix needed - test set threat ratio is not higher than train") + + return train_df, val_df, test_df + +def validate_distributions(train_df, val_df, test_df): + """Validate the threat distributions across all languages""" + print("\nValidating Threat Distributions Across Languages:") + print("-" * 50) + + for lang in sorted(train_df['lang'].unique()): + print(f"\n{lang.upper()}:") + for name, df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]: + stats = 
get_threat_stats(df, lang) + print(f" {name}: {stats['threat_count']}/{stats['total']} ({stats['threat_ratio']:.2%})") + +if __name__ == "__main__": + # Fix Portuguese threat distribution + train_df, val_df, test_df = fix_pt_threat_distribution() + + # Validate distributions across all languages + validate_distributions(train_df, val_df, test_df) \ No newline at end of file diff --git a/utils/merge_and_compare.py b/utils/merge_and_compare.py new file mode 100644 index 0000000000000000000000000000000000000000..1e719e58793fb80c590d9c94e7d246adc8df441b --- /dev/null +++ b/utils/merge_and_compare.py @@ -0,0 +1,107 @@ +import pandas as pd +import numpy as np +from pathlib import Path +import os + +def load_dataset(file_path, encoding='utf-8'): + """Load dataset with fallback encodings""" + encodings = ['utf-8', 'latin1', 'iso-8859-1', 'cp1252'] + + for enc in encodings: + try: + return pd.read_csv(file_path, encoding=enc) + except UnicodeDecodeError: + continue + except Exception as e: + print(f"Error with {enc}: {str(e)}") + continue + + raise ValueError(f"Could not read {file_path} with any encoding") + +def print_dataset_stats(df, name="Dataset"): + """Print detailed statistics about a dataset""" + print(f"\n{name} Statistics:") + print(f"Total comments: {len(df):,}") + + if 'lang' in df.columns: + print("\nLanguage distribution:") + lang_dist = df['lang'].value_counts() + for lang, count in lang_dist.items(): + print(f"{lang}: {count:,} ({count/len(df)*100:.1f}%)") + + toxicity_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + print("\nToxicity distribution:") + for col in toxicity_cols: + if col in df.columns: + toxic_count = (df[col] > 0).sum() + print(f"{col.replace('_', ' ').title()}: {toxic_count:,} ({toxic_count/len(df)*100:.1f}%)") + + if all(col in df.columns for col in toxicity_cols): + toxic_mask = df[toxicity_cols].any(axis=1) + total_toxic = toxic_mask.sum() + print(f"\nTotal Toxic Comments: {total_toxic:,} ({total_toxic/len(df)*100:.1f}%)") + print(f"Total Non-Toxic Comments: {len(df)-total_toxic:,} ({(len(df)-total_toxic)/len(df)*100:.1f}%)") + +def merge_and_compare_datasets(): + """Merge filtered English with non-English data and compare with original""" + + # Define file paths + english_filtered = "dataset/raw/english_filtered.csv" + non_english = "dataset/raw/MULTILINGUAL_TOXIC_DATASET_347k_7LANG_non_english.csv" + original = "dataset/raw/MULTILINGUAL_TOXIC_DATASET_347K_7LANG.csv" + output_file = "dataset/processed/final_merged_dataset.csv" + + print("Loading datasets...") + + # Load English filtered dataset + print("\nLoading filtered English dataset...") + eng_df = load_dataset(english_filtered) + eng_df['lang'] = 'en' # Ensure language column exists + print_dataset_stats(eng_df, "Filtered English Dataset") + + # Load non-English dataset + print("\nLoading non-English dataset...") + non_eng_df = load_dataset(non_english) + print_dataset_stats(non_eng_df, "Non-English Dataset") + + # Merge datasets + print("\nMerging datasets...") + merged_df = pd.concat([eng_df, non_eng_df], ignore_index=True) + print_dataset_stats(merged_df, "Merged Dataset") + + # Load original dataset for comparison + print("\nLoading original dataset for comparison...") + original_df = load_dataset(original) + print_dataset_stats(original_df, "Original Dataset") + + # Compare datasets + print("\nComparison Summary:") + print(f"Original dataset size: {len(original_df):,}") + print(f"Merged dataset size: {len(merged_df):,}") + print(f"Difference: {len(merged_df) - 
len(original_df):,} comments") + + if 'lang' in merged_df.columns and 'lang' in original_df.columns: + print("\nLanguage Distribution Comparison:") + orig_lang = original_df['lang'].value_counts() + new_lang = merged_df['lang'].value_counts() + + all_langs = sorted(set(orig_lang.index) | set(new_lang.index)) + for lang in all_langs: + orig_count = orig_lang.get(lang, 0) + new_count = new_lang.get(lang, 0) + diff = new_count - orig_count + print(f"{lang}:") + print(f" Original: {orig_count:,}") + print(f" New: {new_count:,}") + print(f" Difference: {diff:,} ({diff/orig_count*100:.1f}% change)") + + # Save merged dataset + print(f"\nSaving merged dataset to: {output_file}") + os.makedirs(os.path.dirname(output_file), exist_ok=True) + merged_df.to_csv(output_file, index=False) + print(f"File size: {Path(output_file).stat().st_size / (1024*1024):.1f} MB") + + return merged_df + +if __name__ == "__main__": + merged_df = merge_and_compare_datasets() \ No newline at end of file diff --git a/utils/merge_datasets.py b/utils/merge_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..5f00f12fc22e2c7a407e8f6222b30f06c7b0a424 --- /dev/null +++ b/utils/merge_datasets.py @@ -0,0 +1,75 @@ +import pandas as pd +from pathlib import Path +import logging +from datetime import datetime + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s | %(message)s' +) +logger = logging.getLogger(__name__) + +def merge_datasets(): + """Merge augmented threat dataset with main dataset""" + try: + # Load main dataset + logger.info("Loading main dataset...") + main_df = pd.read_csv("dataset/processed/MULTILINGUAL_TOXIC_DATASET_360K_7LANG_FINAL.csv") + logger.info(f"Main dataset: {len(main_df):,} rows") + + # Load augmented dataset + augmented_path = Path("dataset/augmented") + latest_augmented = max(augmented_path.glob("threat_augmented_*.csv")) + logger.info(f"Loading augmented dataset: {latest_augmented.name}") + aug_df = pd.read_csv(latest_augmented) + logger.info(f"Augmented dataset: {len(aug_df):,} rows") + + # Standardize columns for augmented data + logger.info("Standardizing columns...") + aug_df_standardized = pd.DataFrame({ + 'comment_text': aug_df['text'], + 'toxic': 1, + 'severe_toxic': 0, + 'obscene': 0, + 'threat': 1, + 'insult': 0, + 'identity_hate': 0, + 'lang': 'en' + }) + + # Check for duplicates between datasets + logger.info("Checking for duplicates...") + combined_texts = pd.concat([main_df['comment_text'], aug_df_standardized['comment_text']]) + duplicates = combined_texts.duplicated(keep='first') + duplicate_count = duplicates[len(main_df):].sum() + logger.info(f"Found {duplicate_count} duplicates in augmented data") + + # Remove duplicates from augmented data + aug_df_standardized = aug_df_standardized[~duplicates[len(main_df):].values] + logger.info(f"Augmented dataset after duplicate removal: {len(aug_df_standardized):,} rows") + + # Merge datasets + merged_df = pd.concat([main_df, aug_df_standardized], ignore_index=True) + logger.info(f"Final merged dataset: {len(merged_df):,} rows") + + # Save merged dataset + output_path = f"dataset/processed/MULTILINGUAL_TOXIC_DATASET_AUGMENTED.csv" + merged_df.to_csv(output_path, index=False) + logger.info(f"Saved merged dataset to: {output_path}") + + # Print statistics + logger.info("\nDataset Statistics:") + logger.info(f"Original samples: {len(main_df):,}") + logger.info(f"Added threat samples: {len(aug_df_standardized):,}") + logger.info(f"Total samples: {len(merged_df):,}") + 
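# (Illustrative) the same summary could be logged for every label, e.g.:
+        #     for col in ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']:
+        #         logger.info(f"{col}: {int(merged_df[col].sum()):,} positive samples")
+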
logger.info(f"Threat samples in final dataset: {merged_df['threat'].sum():,}") + + return merged_df + + except Exception as e: + logger.error(f"Error merging datasets: {str(e)}") + raise + +if __name__ == "__main__": + merged_df = merge_datasets() \ No newline at end of file diff --git a/utils/merge_english.py b/utils/merge_english.py new file mode 100644 index 0000000000000000000000000000000000000000..d070535caa2ee74d927328480129b4f146fa093a --- /dev/null +++ b/utils/merge_english.py @@ -0,0 +1,90 @@ +import pandas as pd +import numpy as np +from pathlib import Path +import os + +def load_dataset(file_path, encoding='utf-8'): + """Load dataset with fallback encodings""" + encodings = ['utf-8', 'latin1', 'iso-8859-1', 'cp1252'] + + if encoding != 'utf-8': + encodings.insert(0, encoding) # Try specified encoding first + + for enc in encodings: + try: + return pd.read_csv(file_path, encoding=enc) + except UnicodeDecodeError: + continue + except Exception as e: + print(f"Error with {enc}: {str(e)}") + continue + + raise ValueError(f"Could not read {file_path} with any encoding") + +def merge_english_comments(output_file=None): + """Merge English comments from multiple datasets""" + + # Define input files + multilingual_file = 'dataset/raw/MULTILINGUAL_TOXIC_DATASET_347K_7LANG.csv' + english_file = 'dataset/raw/english-comments-cleaned.csv' + + print("\nProcessing multilingual dataset...") + multi_df = load_dataset(multilingual_file) + # Extract English comments + multi_df = multi_df[multi_df['lang'] == 'en'].copy() + print(f"Found {len(multi_df):,} English comments in multilingual dataset") + + print("\nProcessing English cleaned dataset...") + eng_df = load_dataset(english_file) + print(f"Found {len(eng_df):,} comments in English dataset") + + # Ensure both dataframes have the same columns + required_cols = ['comment_text', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] + + # Handle multilingual dataset + if 'comment_text' not in multi_df.columns and 'text' in multi_df.columns: + multi_df['comment_text'] = multi_df['text'] + + # Add missing toxicity columns with 0s if they don't exist + for col in required_cols[1:]: # Skip comment_text + if col not in multi_df.columns: + multi_df[col] = 0 + if col not in eng_df.columns: + eng_df[col] = 0 + + # Keep only required columns + multi_df = multi_df[required_cols] + eng_df = eng_df[required_cols] + + # Merge datasets + print("\nMerging datasets...") + merged_df = pd.concat([multi_df, eng_df], ignore_index=True) + initial_count = len(merged_df) + print(f"Initial merged size: {initial_count:,} comments") + + # Remove exact duplicates + merged_df = merged_df.drop_duplicates(subset=['comment_text'], keep='first') + final_count = len(merged_df) + print(f"After removing duplicates: {final_count:,} comments") + print(f"Removed {initial_count - final_count:,} duplicates") + + # Print toxicity distribution + print("\nToxicity distribution in final dataset:") + for col in required_cols[1:]: + toxic_count = (merged_df[col] > 0).sum() + print(f"{col.replace('_', ' ').title()}: {toxic_count:,} ({toxic_count/final_count*100:.1f}%)") + + # Save merged dataset + if output_file is None: + output_file = "dataset/processed/english_merged.csv" + + os.makedirs(os.path.dirname(output_file), exist_ok=True) + print(f"\nSaving merged dataset to: {output_file}") + merged_df.to_csv(output_file, index=False) + print(f"File size: {Path(output_file).stat().st_size / (1024*1024):.1f} MB") + + return merged_df + +if __name__ == "__main__": + output_file = 
"dataset/processed/english_merged.csv" + merged_df = merge_english_comments(output_file) \ No newline at end of file diff --git a/utils/parquet_to_csv.py b/utils/parquet_to_csv.py new file mode 100644 index 0000000000000000000000000000000000000000..c710c153e974ff0b783f8a3408e25aeb2188e153 --- /dev/null +++ b/utils/parquet_to_csv.py @@ -0,0 +1,51 @@ +import pandas as pd +from pathlib import Path +import sys +from tqdm import tqdm + +def convert_parquet_to_csv(parquet_path, csv_path=None): + """Convert a parquet file to CSV with progress tracking""" + print(f"\nReading parquet file: {parquet_path}") + + # If no CSV path specified, use the same name with .csv extension + if csv_path is None: + csv_path = str(Path(parquet_path).with_suffix('.csv')) + + try: + # Read parquet file + df = pd.read_parquet(parquet_path) + total_rows = len(df) + + print(f"\nDataset Info:") + print(f"Rows: {total_rows:,}") + print(f"Columns: {', '.join(df.columns)}") + print(f"\nSaving to CSV: {csv_path}") + + # Save to CSV with progress bar + with tqdm(total=total_rows, desc="Converting") as pbar: + # Use chunksize for memory efficiency + chunk_size = 10000 + for i in range(0, total_rows, chunk_size): + end_idx = min(i + chunk_size, total_rows) + chunk = df.iloc[i:end_idx] + + # Write mode: 'w' for first chunk, 'a' for rest + mode = 'w' if i == 0 else 'a' + header = i == 0 # Only write header for first chunk + + chunk.to_csv(csv_path, mode=mode, header=header, index=False) + pbar.update(len(chunk)) + + print(f"\n✓ Successfully converted to CSV") + print(f"Output file size: {Path(csv_path).stat().st_size / (1024*1024):.1f} MB") + + except Exception as e: + print(f"\n❌ Error: {str(e)}") + sys.exit(1) + +if __name__ == "__main__": + + parquet_path = "dataset/raw/jigsaw-toxic-comment-train-processed-seqlen128_original .parquet" + csv_path = "dataset/raw/jigsaw-en-only-toxic-comment-train-processed-seqlen128_original.csv" + + convert_parquet_to_csv(parquet_path, csv_path) \ No newline at end of file diff --git a/utils/process_dataset.py b/utils/process_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6534cd321a04bc66e48c55e49b8f63b4a9d7b749 --- /dev/null +++ b/utils/process_dataset.py @@ -0,0 +1,113 @@ +import pandas as pd +import numpy as np +from text_preprocessor import TextPreprocessor +from tqdm import tqdm +import logging +from pathlib import Path +import time + +def process_dataset(input_path: str, output_path: str = None, batch_size: int = 1000): + """ + Process a dataset using the TextPreprocessor with efficient batch processing. + + Args: + input_path: Path to input CSV file + output_path: Path to save processed CSV file. 
If None, will use input name with _processed suffix + batch_size: Number of texts to process in each batch + """ + # Setup output path + if output_path is None: + input_path = Path(input_path) + output_path = input_path.parent / f"{input_path.stem}_processed{input_path.suffix}" + + # Initialize preprocessor + preprocessor = TextPreprocessor() + + print(f"\nProcessing dataset: {input_path}") + start_time = time.time() + + try: + # Read the dataset + print("Reading dataset...") + df = pd.read_csv(input_path) + total_rows = len(df) + print(f"Total rows: {total_rows:,}") + + # Process in batches with progress bar + print("\nProcessing text...") + + # Calculate number of batches + num_batches = (total_rows + batch_size - 1) // batch_size + + for i in tqdm(range(0, total_rows, batch_size), total=num_batches, desc="Processing batches"): + # Get batch + batch_start = i + batch_end = min(i + batch_size, total_rows) + + # Process each text in the batch + for idx in range(batch_start, batch_end): + text = df.loc[idx, 'comment_text'] + lang = df.loc[idx, 'lang'] if 'lang' in df.columns else 'en' + + # Process text + processed = preprocessor.preprocess_text( + text, + lang=lang, + clean_options={ + 'remove_stops': True, + 'remove_numbers': True, + 'remove_urls': True, + 'remove_emails': True, + 'remove_mentions': True, + 'remove_hashtags': True, + 'expand_contractions': True, + 'remove_accents': False, + 'min_word_length': 2 + }, + do_stemming=True + ) + + # Update the text directly + df.loc[idx, 'comment_text'] = processed + + # Optional: Print sample from first batch + if i == 0: + print("\nSample processing results:") + for j in range(min(3, batch_size)): + print(f"\nProcessed text {j+1}: {df.loc[j, 'comment_text'][:100]}...") + + # Save processed dataset + print(f"\nSaving processed dataset to: {output_path}") + df.to_csv(output_path, index=False) + + # Print statistics + end_time = time.time() + processing_time = end_time - start_time + + print("\nProcessing Complete!") + print("-" * 50) + print(f"Total rows processed: {total_rows:,}") + print(f"Processing time: {processing_time/60:.2f} minutes") + print(f"Average time per text: {processing_time/total_rows*1000:.2f} ms") + print(f"Output file size: {Path(output_path).stat().st_size/1024/1024:.1f} MB") + + # Print sample of unique words before and after + print("\nVocabulary Statistics:") + sample_size = min(1000, total_rows) + original_words = set(' '.join(df['comment_text'].head(sample_size).astype(str)).split()) + processed_words = set(' '.join(df['processed_text'].head(sample_size).astype(str)).split()) + print(f"Sample unique words (first {sample_size:,} rows):") + print(f"Before processing: {len(original_words):,}") + print(f"After processing : {len(processed_words):,}") + print(f"Reduction: {(1 - len(processed_words)/len(original_words))*100:.1f}%") + + except Exception as e: + print(f"\nError processing dataset: {str(e)}") + raise + +if __name__ == "__main__": + # Process training dataset + input_file = "dataset/split/train.csv" + output_file = "dataset/split/train_no_stopwords.csv" + + process_dataset(input_file, output_file) \ No newline at end of file diff --git a/utils/remove_english.py b/utils/remove_english.py new file mode 100644 index 0000000000000000000000000000000000000000..f43c78f2fb6728f39880bde87df2304e0392d790 --- /dev/null +++ b/utils/remove_english.py @@ -0,0 +1,49 @@ +import pandas as pd +from pathlib import Path +import sys +from tqdm import tqdm + +def remove_english_comments(input_path, output_path=None): + """Remove 
English comments from a dataset with progress tracking""" + print(f"\nReading input file: {input_path}") + + # If no output path specified, use input name with _non_english suffix + if output_path is None: + output_path = str(Path(input_path).with_suffix('').with_name(f"{Path(input_path).stem}_non_english.csv")) + + try: + # Read input file with UTF-8 encoding + df = pd.read_csv(input_path, encoding='utf-8') + total_rows = len(df) + + print(f"\nDataset Info:") + print(f"Initial Rows: {total_rows:,}") + print(f"Columns: {', '.join(df.columns)}") + + # Filter out English comments (where lang == 'en') + print("\nFiltering out English comments...") + non_english_df = df[df['lang'] != 'en'] + + # Save to CSV with UTF-8 encoding + print(f"\nSaving to: {output_path}") + non_english_df.to_csv(output_path, index=False, encoding='utf-8') + + # Get statistics + english_rows = total_rows - len(non_english_df) + + print(f"\n✓ Successfully removed English comments") + print(f"Initial rows: {total_rows:,}") + print(f"Remaining non-English rows: {len(non_english_df):,}") + print(f"Removed English rows: {english_rows:,}") + print(f"Output file: {output_path}") + print(f"Output file size: {Path(output_path).stat().st_size / (1024*1024):.1f} MB") + + except Exception as e: + print(f"\n❌ Error: {str(e)}") + sys.exit(1) + +if __name__ == "__main__": + input_path = "dataset/raw/MULTILINGUAL_TOXIC_DATASET_347k_7LANG.csv" + output_path = input_path.replace(".csv", "_non_english.csv") + + remove_english_comments(input_path, output_path) \ No newline at end of file diff --git a/utils/remove_leakage.py b/utils/remove_leakage.py new file mode 100644 index 0000000000000000000000000000000000000000..07b526a844f01c277d6ce97047526cd34afae3d8 --- /dev/null +++ b/utils/remove_leakage.py @@ -0,0 +1,116 @@ +import pandas as pd +import hashlib +import os +from collections import defaultdict +from pathlib import Path + +def text_hash(text): + """Create a hash of the text after basic normalization""" + # Convert to string and normalize + text = str(text).strip().lower() + # Remove extra whitespace + text = ' '.join(text.split()) + # Create hash + return hashlib.sha256(text.encode()).hexdigest() + +def remove_leaked_samples(train_path, val_path, test_path, output_dir='dataset/clean'): + """Remove overlapping samples between dataset splits""" + print("\n=== Removing Data Leakage ===\n") + + # Create hash registry + hash_registry = defaultdict(set) + splits = {} + original_sizes = {} + + # Create output directory + Path(output_dir).mkdir(parents=True, exist_ok=True) + + # Load datasets + print("Loading datasets...") + splits = { + 'train': pd.read_csv(train_path), + 'val': pd.read_csv(val_path), + 'test': pd.read_csv(test_path) + } + + # Store original sizes + for split_name, df in splits.items(): + original_sizes[split_name] = len(df) + print(f"Original {split_name} size: {len(df):,} samples") + + # Process each split + print("\nChecking for overlaps...") + removed_counts = defaultdict(int) + + for split_name, df in splits.items(): + print(f"\nProcessing {split_name} split...") + + # Calculate hashes for current split + current_hashes = set(df['comment_text'].apply(text_hash)) + hash_registry[split_name] = current_hashes + + # Check overlaps with other splits + for other_split in splits: + if other_split != split_name: + if hash_registry[other_split]: # Only check if other split is processed + overlaps = current_hashes & hash_registry[other_split] + if overlaps: + print(f" Found {len(overlaps):,} overlaps with {other_split}") + # 
Remove overlapping samples + df = df[~df['comment_text'].apply(text_hash).isin(overlaps)] + removed_counts[f"{split_name}_from_{other_split}"] = len(overlaps) + + # Update splits dictionary with cleaned dataframe + splits[split_name] = df + + # Save cleaned splits + print("\nSaving cleaned datasets...") + for split_name, df in splits.items(): + output_path = os.path.join(output_dir, f"{split_name}_clean.csv") + df.to_csv(output_path, index=False) + reduction = ((original_sizes[split_name] - len(df)) / original_sizes[split_name]) * 100 + print(f"Cleaned {split_name}: {len(df):,} samples (-{reduction:.2f}%)") + + # Print detailed overlap statistics + print("\nDetailed Overlap Statistics:") + print("-" * 50) + for overlap_type, count in removed_counts.items(): + split_name, other_split = overlap_type.split('_from_') + print(f"{split_name} → {other_split}: {count:,} overlapping samples removed") + + return splits + +def validate_cleaning(splits): + """Validate that no overlaps remain between splits""" + print("\nValidating Cleaning...") + print("-" * 50) + + all_clean = True + for split1 in splits: + for split2 in splits: + if split1 < split2: # Check each pair only once + hashes1 = set(splits[split1]['comment_text'].apply(text_hash)) + hashes2 = set(splits[split2]['comment_text'].apply(text_hash)) + overlaps = hashes1 & hashes2 + if overlaps: + print(f"⚠️ Warning: Found {len(overlaps)} overlaps between {split1} and {split2}") + all_clean = False + else: + print(f"✅ No overlaps between {split1} and {split2}") + + if all_clean: + print("\n✅ All splits are now clean with no overlaps!") + else: + print("\n⚠️ Some overlaps still remain. Consider additional cleaning.") + +if __name__ == "__main__": + # Define paths + train_path = "dataset/split/train.csv" + val_path = "dataset/split/val.csv" + test_path = "dataset/split/test.csv" + + # Remove leaked samples + cleaned_splits = remove_leaked_samples(train_path, val_path, test_path) + + # Validate cleaning + validate_cleaning(cleaned_splits) \ No newline at end of file diff --git a/utils/shuffle_dataset.py b/utils/shuffle_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d4d7d220c7b44c1b58f893884c8fced45053d5f1 --- /dev/null +++ b/utils/shuffle_dataset.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +""" +Thoroughly shuffle the dataset while maintaining class distributions and data integrity. +This script implements stratified shuffling to ensure balanced representation of classes +and languages in the shuffled data. +""" + +import pandas as pd +import numpy as np +from pathlib import Path +import argparse +from sklearn.model_selection import StratifiedKFold +from collections import defaultdict +import logging +import json +from typing import List, Dict, Tuple +import sys +from datetime import datetime + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout), + logging.FileHandler(f'logs/shuffle_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log') + ] +) +logger = logging.getLogger(__name__) + +def create_stratification_label(row: pd.Series, toxicity_labels: List[str]) -> str: + """ + Create a composite label for stratification that captures the combination of + toxicity labels and language. 
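+
+    For example (illustrative), with the label order defined in main()
+    (toxic, severe_toxic, obscene, threat, insult, identity_hate), a French row
+    where only toxic and insult are 1 maps to "fr_100010".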
+ """ + # Convert toxicity values to binary string + toxicity_str = ''.join(['1' if row[label] == 1 else '0' for label in toxicity_labels]) + # Combine with language + return f"{row['lang']}_{toxicity_str}" + +def validate_data(df: pd.DataFrame, toxicity_labels: List[str]) -> bool: + """ + Validate the dataset for required columns and data integrity. + """ + try: + # Check required columns + required_columns = ['comment_text', 'lang'] + toxicity_labels + missing_columns = [col for col in required_columns if col not in df.columns] + if missing_columns: + raise ValueError(f"Missing required columns: {missing_columns}") + + # Check for null values in critical columns + null_counts = df[required_columns].isnull().sum() + if null_counts.any(): + logger.warning(f"Found null values:\n{null_counts[null_counts > 0]}") + + # Validate label values are binary + for label in toxicity_labels: + invalid_values = df[label][~df[label].isin([0, 1, np.nan])] + if not invalid_values.empty: + raise ValueError(f"Found non-binary values in {label}: {invalid_values.unique()}") + + # Validate text content + if df['comment_text'].str.len().min() == 0: + logger.warning("Found empty comments in dataset") + + return True + + except Exception as e: + logger.error(f"Data validation failed: {str(e)}") + return False + +def analyze_distribution(df: pd.DataFrame, toxicity_labels: List[str]) -> Dict: + """ + Analyze the class distribution and language distribution in the dataset. + """ + stats = { + 'total_samples': len(df), + 'language_distribution': df['lang'].value_counts().to_dict(), + 'class_distribution': { + label: { + 'positive': int(df[label].sum()), + 'negative': int(len(df) - df[label].sum()), + 'ratio': float(df[label].mean()) + } + for label in toxicity_labels + }, + 'language_class_distribution': defaultdict(dict) + } + + # Calculate per-language class distributions + for lang in df['lang'].unique(): + lang_df = df[df['lang'] == lang] + stats['language_class_distribution'][lang] = { + label: { + 'positive': int(lang_df[label].sum()), + 'negative': int(len(lang_df) - lang_df[label].sum()), + 'ratio': float(lang_df[label].mean()) + } + for label in toxicity_labels + } + + return stats + +def shuffle_dataset( + input_file: str, + output_file: str, + toxicity_labels: List[str], + n_splits: int = 10, + random_state: int = 42 +) -> Tuple[bool, Dict]: + """ + Thoroughly shuffle the dataset while maintaining class distributions. + Uses stratified k-fold splitting for balanced shuffling. 
+ """ + try: + logger.info(f"Loading dataset from {input_file}") + df = pd.read_csv(input_file) + + # Validate data + if not validate_data(df, toxicity_labels): + return False, {} + + # Analyze initial distribution + initial_stats = analyze_distribution(df, toxicity_labels) + logger.info("Initial distribution stats:") + logger.info(json.dumps(initial_stats, indent=2)) + + # Create stratification labels + logger.info("Creating stratification labels") + df['strat_label'] = df.apply( + lambda row: create_stratification_label(row, toxicity_labels), + axis=1 + ) + + # Initialize stratified k-fold + skf = StratifiedKFold( + n_splits=n_splits, + shuffle=True, + random_state=random_state + ) + + # Get shuffled indices using stratified split + logger.info(f"Performing stratified shuffling with {n_splits} splits") + all_indices = [] + for _, fold_indices in skf.split(df, df['strat_label']): + all_indices.extend(fold_indices) + + # Create shuffled dataframe + shuffled_df = df.iloc[all_indices].copy() + shuffled_df = shuffled_df.drop('strat_label', axis=1) + + # Analyze final distribution + final_stats = analyze_distribution(shuffled_df, toxicity_labels) + + # Save shuffled dataset + logger.info(f"Saving shuffled dataset to {output_file}") + shuffled_df.to_csv(output_file, index=False) + + # Save distribution statistics + stats_file = Path(output_file).parent / 'shuffle_stats.json' + stats = { + 'initial': initial_stats, + 'final': final_stats, + 'shuffle_params': { + 'n_splits': n_splits, + 'random_state': random_state + } + } + with open(stats_file, 'w') as f: + json.dump(stats, f, indent=2) + + logger.info(f"Shuffling complete. Statistics saved to {stats_file}") + return True, stats + + except Exception as e: + logger.error(f"Error shuffling dataset: {str(e)}") + return False, {} + +def main(): + parser = argparse.ArgumentParser(description='Thoroughly shuffle the dataset.') + parser.add_argument( + '--input', + type=str, + required=True, + help='Input CSV file path' + ) + parser.add_argument( + '--output', + type=str, + required=True, + help='Output CSV file path' + ) + parser.add_argument( + '--splits', + type=int, + default=10, + help='Number of splits for stratified shuffling (default: 10)' + ) + parser.add_argument( + '--seed', + type=int, + default=42, + help='Random seed (default: 42)' + ) + args = parser.parse_args() + + # Create output directory if it doesn't exist + Path(args.output).parent.mkdir(parents=True, exist_ok=True) + + # Create logs directory if it doesn't exist + Path('logs').mkdir(exist_ok=True) + + # Define toxicity labels + toxicity_labels = [ + 'toxic', 'severe_toxic', 'obscene', 'threat', + 'insult', 'identity_hate' + ] + + # Shuffle dataset + success, stats = shuffle_dataset( + args.input, + args.output, + toxicity_labels, + args.splits, + args.seed + ) + + if success: + logger.info("Dataset shuffling completed successfully") + # Print final class distribution + for label, dist in stats['final']['class_distribution'].items(): + logger.info(f"{label}: {dist['ratio']:.3f} " + f"(+:{dist['positive']}, -:{dist['negative']})") + else: + logger.error("Dataset shuffling failed") + sys.exit(1) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/utils/split_dataset.py b/utils/split_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..49ed22a42ad3e1a762af6682f0c8242a32ee670a --- /dev/null +++ b/utils/split_dataset.py @@ -0,0 +1,360 @@ +#!/usr/bin/env python3 +import pandas as pd +import numpy as np +from sklearn.model_selection 
import StratifiedKFold +from pathlib import Path +import json +from collections import defaultdict +import logging +from typing import Dict, Tuple, Set +import time +from itertools import combinations +import hashlib +from tqdm import tqdm + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) + +TOXICITY_COLUMNS = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] +RARE_CLASSES = ['threat', 'identity_hate'] +MIN_SAMPLES_PER_CLASS = 1000 # Minimum samples required per class per language + +def create_multilabel_stratification_labels(row: pd.Series) -> str: + """ + Create composite labels that preserve multi-label patterns and language distribution. + Uses iterative label combination to capture co-occurrence patterns. + """ + # Create base label from language + label = str(row['lang']) + + # Add individual class information + for col in TOXICITY_COLUMNS: + label += '_' + str(int(row[col])) + + # Add co-occurrence patterns for pairs of classes + for c1, c2 in combinations(RARE_CLASSES, 2): + co_occur = int(row[c1] == 1 and row[c2] == 1) + label += '_' + str(co_occur) + + return label + +def oversample_rare_classes(df: pd.DataFrame) -> pd.DataFrame: + """ + Perform intelligent oversampling of rare classes while maintaining language distribution. + """ + oversampled_dfs = [] + original_df = df.copy() + + # Process each language separately + for lang in df['lang'].unique(): + lang_df = df[df['lang'] == lang] + + for rare_class in RARE_CLASSES: + class_samples = lang_df[lang_df[rare_class] == 1] + target_samples = MIN_SAMPLES_PER_CLASS + + if len(class_samples) < target_samples: + # Calculate number of samples needed + n_samples = target_samples - len(class_samples) + + # Oversample with small random variations + noise = np.random.normal(0, 0.1, (n_samples, len(TOXICITY_COLUMNS))) + oversampled = class_samples.sample(n_samples, replace=True) + + # Add noise to continuous values while keeping binary values intact + for col in TOXICITY_COLUMNS: + if col in [rare_class] + [c for c in RARE_CLASSES if c != rare_class]: + continue # Preserve original binary values for rare classes + oversampled[col] = np.clip( + oversampled[col].values + noise[:, TOXICITY_COLUMNS.index(col)], + 0, 1 + ) + + oversampled_dfs.append(oversampled) + + if oversampled_dfs: + return pd.concat([original_df] + oversampled_dfs, axis=0).reset_index(drop=True) + return original_df + +def verify_distributions( + original_df: pd.DataFrame, + train_df: pd.DataFrame, + val_df: pd.DataFrame, + test_df: pd.DataFrame = None +) -> Dict: + """ + Enhanced verification of distributions across splits with detailed metrics. 
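+
+    Returns a dict keyed by split name ('original', 'train', 'val' and, when a
+    test split is given, 'test'), holding the language distribution, per-language
+    class counts, label co-occurrence ratios and, for the non-original splits,
+    per-class deltas from the original distribution.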
+ """ + splits = { + 'original': original_df, + 'train': train_df, + 'val': val_df + } + if test_df is not None: + splits['test'] = test_df + + stats = defaultdict(dict) + + for split_name, df in splits.items(): + # Language distribution + stats[split_name]['language_dist'] = df['lang'].value_counts(normalize=True).to_dict() + + # Per-language class distributions + lang_class_dist = {} + for lang in df['lang'].unique(): + lang_df = df[df['lang'] == lang] + lang_class_dist[lang] = { + col: { + 'positive_ratio': lang_df[col].mean(), + 'count': int(lang_df[col].sum()), + 'total': len(lang_df) + } for col in TOXICITY_COLUMNS + } + stats[split_name]['lang_class_dist'] = lang_class_dist + + # Multi-label co-occurrence patterns + cooccurrence = {} + for c1, c2 in combinations(TOXICITY_COLUMNS, 2): + cooccur_count = ((df[c1] == 1) & (df[c2] == 1)).sum() + cooccurrence[f"{c1}_{c2}"] = { + 'count': int(cooccur_count), + 'ratio': float(cooccur_count) / len(df) + } + stats[split_name]['cooccurrence_patterns'] = cooccurrence + + # Distribution deltas from original + if split_name != 'original': + deltas = {} + for lang in df['lang'].unique(): + for col in TOXICITY_COLUMNS: + orig_ratio = splits['original'][splits['original']['lang'] == lang][col].mean() + split_ratio = df[df['lang'] == lang][col].mean() + deltas[f"{lang}_{col}"] = abs(orig_ratio - split_ratio) + stats[split_name]['distribution_deltas'] = deltas + + return stats + +def check_contamination( + train_df: pd.DataFrame, + val_df: pd.DataFrame, + test_df: pd.DataFrame = None +) -> Dict: + """ + Enhanced contamination check including text similarity detection. + """ + # Determine the correct text column name + text_column = 'comment_text' if 'comment_text' in train_df.columns else 'text' + if text_column not in train_df.columns: + logging.warning("No text column found for contamination check. Skipping text-based contamination detection.") + return {'exact_matches': {'train_val': 0.0}} + + def get_text_hash_set(df: pd.DataFrame) -> Set[str]: + return set(df[text_column].str.lower().str.strip().values) + + contamination = { + 'exact_matches': { + 'train_val': len(get_text_hash_set(train_df) & get_text_hash_set(val_df)) / len(train_df) + } + } + + if test_df is not None: + contamination['exact_matches'].update({ + 'train_test': len(get_text_hash_set(train_df) & get_text_hash_set(test_df)) / len(train_df), + 'val_test': len(get_text_hash_set(val_df) & get_text_hash_set(test_df)) / len(val_df) + }) + + return contamination + +def split_dataset( + df: pd.DataFrame, + seed: int, + split_mode: str +) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + Perform stratified splitting of the dataset. 
+ """ + # Create stratification labels + logging.info("Creating stratification labels...") + stratify_labels = df.apply(create_multilabel_stratification_labels, axis=1) + + # Oversample rare classes in training data only + logging.info("Oversampling rare classes...") + df_with_oversampling = oversample_rare_classes(df) + + # Initialize splits + if split_mode == '3': + # First split: 80% train, 20% temp + splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed) + train_idx, temp_idx = next(splitter.split(df, stratify_labels)) + + # Second split: 10% val, 10% test from temp + temp_df = df.iloc[temp_idx] + temp_labels = stratify_labels.iloc[temp_idx] + + splitter = StratifiedKFold(n_splits=2, shuffle=True, random_state=seed) + val_idx, test_idx = next(splitter.split(temp_df, temp_labels)) + + # Create final splits + train_df = df_with_oversampling.iloc[train_idx] # Use oversampled data for training + val_df = df.iloc[temp_idx].iloc[val_idx] # Use original data for validation + test_df = df.iloc[temp_idx].iloc[test_idx] # Use original data for testing + + else: # 2-way split + splitter = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed) + train_idx, val_idx = next(splitter.split(df, stratify_labels)) + + train_df = df_with_oversampling.iloc[train_idx] # Use oversampled data for training + val_df = df.iloc[val_idx] # Use original data for validation + test_df = None + + return train_df, val_df, test_df + +def save_splits( + train_df: pd.DataFrame, + val_df: pd.DataFrame, + test_df: pd.DataFrame, + output_dir: str, + stats: Dict +) -> None: + """ + Save splits and statistics to files. + """ + # Create output directory + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + # Save splits + logging.info("Saving splits...") + train_df.to_csv(output_path / 'train.csv', index=False) + val_df.to_csv(output_path / 'val.csv', index=False) + if test_df is not None: + test_df.to_csv(output_path / 'test.csv', index=False) + + # Save statistics + with open(output_path / 'stats.json', 'w', encoding='utf-8') as f: + json.dump(stats, f, indent=2, ensure_ascii=False) + +def compute_text_hash(text: str) -> str: + """ + Compute SHA-256 hash of normalized text. + """ + # Normalize text by removing extra whitespace and converting to lowercase + normalized = ' '.join(str(text).lower().split()) + return hashlib.sha256(normalized.encode('utf-8')).hexdigest() + +def deduplicate_dataset(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict]: + """ + Remove duplicates using cryptographic hashing while preserving metadata. + """ + logging.info("Starting cryptographic deduplication...") + + # Determine text column + text_column = 'comment_text' if 'comment_text' in df.columns else 'text' + if text_column not in df.columns: + raise ValueError(f"No text column found. 
Available columns: {df.columns}") + + # Compute hashes with progress bar + logging.info("Computing cryptographic hashes...") + tqdm.pandas(desc="Hashing texts") + df['text_hash'] = df[text_column].progress_apply(compute_text_hash) + + # Get duplicate statistics before removal + total_samples = len(df) + duplicate_hashes = df[df.duplicated('text_hash', keep=False)]['text_hash'].unique() + duplicate_groups = { + hash_val: df[df['text_hash'] == hash_val].index.tolist() + for hash_val in duplicate_hashes + } + + # Keep first occurrence of each text while tracking duplicates + dedup_df = df.drop_duplicates('text_hash', keep='first').copy() + dedup_df = dedup_df.drop('text_hash', axis=1) + + # Compile deduplication statistics + dedup_stats = { + 'total_samples': total_samples, + 'unique_samples': len(dedup_df), + 'duplicates_removed': total_samples - len(dedup_df), + 'duplicate_rate': (total_samples - len(dedup_df)) / total_samples, + 'duplicate_groups': { + str(k): { + 'count': len(v), + 'indices': v + } + for k, v in duplicate_groups.items() + } + } + + logging.info(f"Removed {dedup_stats['duplicates_removed']:,} duplicates " + f"({dedup_stats['duplicate_rate']:.2%} of dataset)") + + return dedup_df, dedup_stats + +def main(): + input_csv = 'dataset/processed/MULTILINGUAL_TOXIC_DATASET_AUGMENTED.csv' + output_dir = 'dataset/split' + seed = 42 + split_mode = '3' + + start_time = time.time() + + # Load dataset + logging.info(f"Loading dataset from {input_csv}...") + df = pd.read_csv(input_csv) + + # Print column names for debugging + logging.info(f"Available columns: {', '.join(df.columns)}") + + # Verify required columns + required_columns = ['lang'] + TOXICITY_COLUMNS + missing_columns = [col for col in required_columns if col not in df.columns] + if missing_columns: + raise ValueError(f"Missing required columns: {missing_columns}") + + # Perform deduplication + df, dedup_stats = deduplicate_dataset(df) + + # Perform splitting + logging.info("Performing stratified split...") + train_df, val_df, test_df = split_dataset(df, seed, split_mode) + + # Verify distributions + logging.info("Verifying distributions...") + stats = verify_distributions(df, train_df, val_df, test_df) + + # Add deduplication stats + stats['deduplication'] = dedup_stats + + # Check contamination + logging.info("Checking for contamination...") + contamination = check_contamination(train_df, val_df, test_df) + stats['contamination'] = contamination + + # Save everything + logging.info(f"Saving splits to {output_dir}...") + save_splits(train_df, val_df, test_df, output_dir, stats) + + elapsed_time = time.time() - start_time + logging.info(f"Done! 
Elapsed time: {elapsed_time:.2f} seconds") + + # Print summary + print("\nDeduplication Summary:") + print("-" * 50) + print(f"Original samples: {dedup_stats['total_samples']:,}") + print(f"Unique samples: {dedup_stats['unique_samples']:,}") + print(f"Duplicates removed: {dedup_stats['duplicates_removed']:,} ({dedup_stats['duplicate_rate']:.2%})") + + print("\nSplit Summary:") + print("-" * 50) + print(f"Total samples: {len(df):,}") + print(f"Train samples: {len(train_df):,} ({len(train_df)/len(df)*100:.1f}%)") + print(f"Validation samples: {len(val_df):,} ({len(val_df)/len(df)*100:.1f}%)") + if test_df is not None: + print(f"Test samples: {len(test_df):,} ({len(test_df)/len(df)*100:.1f}%)") + print("\nDetailed statistics saved to stats.json") + +if __name__ == "__main__": + main() diff --git a/utils/text_preprocessor.py b/utils/text_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..47a2cce15e9084c3f28e525ed855982cc283182d --- /dev/null +++ b/utils/text_preprocessor.py @@ -0,0 +1,285 @@ +import re +import nltk +import logging +from typing import List, Set, Dict, Optional +from nltk.tokenize import word_tokenize +from nltk.corpus import stopwords +from nltk.stem import SnowballStemmer +from TurkishStemmer import TurkishStemmer +from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning +import unicodedata +import warnings + +# Suppress BeautifulSoup warning about markup resembling a filename +warnings.filterwarnings("ignore", category=MarkupResemblesLocatorWarning) + +# Download required NLTK data +try: + nltk.download('stopwords', quiet=True) + nltk.download('punkt', quiet=True) + nltk.download('punkt_tab', quiet=True) + nltk.download('averaged_perceptron_tagger', quiet=True) +except Exception as e: + print(f"Warning: Could not download NLTK data: {str(e)}") + +# Configure logging +logging.basicConfig(level=logging.WARNING) + +class TextPreprocessor: + """ + A comprehensive text preprocessor for multilingual text cleaning and normalization. + Supports multiple languages and provides various text cleaning operations. + """ + + SUPPORTED_LANGUAGES = {'en', 'es', 'fr', 'it', 'pt', 'ru', 'tr'} + + # Common contractions mapping (can be extended) + CONTRACTIONS = { + "ain't": "is not", "aren't": "are not", "can't": "cannot", + "couldn't": "could not", "didn't": "did not", "doesn't": "does not", + "don't": "do not", "hadn't": "had not", "hasn't": "has not", + "haven't": "have not", "he'd": "he would", "he'll": "he will", + "he's": "he is", "i'd": "i would", "i'll": "i will", "i'm": "i am", + "i've": "i have", "isn't": "is not", "it's": "it is", + "let's": "let us", "shouldn't": "should not", "that's": "that is", + "there's": "there is", "they'd": "they would", "they'll": "they will", + "they're": "they are", "they've": "they have", "wasn't": "was not", + "we'd": "we would", "we're": "we are", "we've": "we have", + "weren't": "were not", "what's": "what is", "where's": "where is", + "who's": "who is", "won't": "will not", "wouldn't": "would not", + "you'd": "you would", "you'll": "you will", "you're": "you are", + "you've": "you have" + } + + def __init__(self, languages: Optional[Set[str]] = None): + """ + Initialize the text preprocessor with specified languages. + + Args: + languages: Set of language codes to support. If None, all supported languages are used. 
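+
+        Example (illustrative):
+            TextPreprocessor(languages={'en', 'tr'}) loads stop words and
+            stemmers for English and Turkish only.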
+ """ + self.languages = languages or self.SUPPORTED_LANGUAGES + self._initialize_resources() + + def _initialize_resources(self): + """Initialize language-specific resources like stop words and stemmers.""" + # Initialize logging + self.logger = logging.getLogger(__name__) + + # Initialize stop words for each language + self.stop_words = {} + nltk_langs = { + 'en': 'english', 'es': 'spanish', 'fr': 'french', + 'it': 'italian', 'pt': 'portuguese', 'ru': 'russian' + } + + for lang, nltk_name in nltk_langs.items(): + if lang in self.languages: + try: + self.stop_words[lang] = set(stopwords.words(nltk_name)) + except Exception as e: + self.logger.warning(f"Could not load stop words for {lang}: {str(e)}") + self.stop_words[lang] = set() + + # Add Turkish stop words manually + if 'tr' in self.languages: + self.stop_words['tr'] = { + 'acaba', 'ama', 'aslında', 'az', 'bazı', 'belki', 'biri', 'birkaç', + 'birşey', 'biz', 'bu', 'çok', 'çünkü', 'da', 'daha', 'de', 'defa', + 'diye', 'eğer', 'en', 'gibi', 'hem', 'hep', 'hepsi', 'her', 'hiç', + 'için', 'ile', 'ise', 'kez', 'ki', 'kim', 'mı', 'mu', 'mü', 'nasıl', + 'ne', 'neden', 'nerde', 'nerede', 'nereye', 'niçin', 'niye', 'o', + 'sanki', 'şey', 'siz', 'şu', 'tüm', 've', 'veya', 'ya', 'yani' + } + + # Initialize stemmers + self.stemmers = {} + for lang, name in [ + ('en', 'english'), ('es', 'spanish'), ('fr', 'french'), + ('it', 'italian'), ('pt', 'portuguese'), ('ru', 'russian') + ]: + if lang in self.languages: + self.stemmers[lang] = SnowballStemmer(name) + + # Initialize Turkish stemmer separately + if 'tr' in self.languages: + self.stemmers['tr'] = TurkishStemmer() + + def remove_html(self, text: str) -> str: + """Remove HTML tags from text.""" + return BeautifulSoup(text, "html.parser").get_text() + + def expand_contractions(self, text: str) -> str: + """Expand contractions in English text.""" + for contraction, expansion in self.CONTRACTIONS.items(): + text = re.sub(rf'\b{contraction}\b', expansion, text, flags=re.IGNORECASE) + return text + + def remove_accents(self, text: str) -> str: + """Remove accents from text while preserving base characters.""" + return ''.join(c for c in unicodedata.normalize('NFKD', text) + if not unicodedata.combining(c)) + + def clean_text(self, text: str, lang: str = 'en', + remove_stops: bool = True, + remove_numbers: bool = True, + remove_urls: bool = True, + remove_emails: bool = True, + remove_mentions: bool = True, + remove_hashtags: bool = True, + expand_contractions: bool = True, + remove_accents: bool = False, + min_word_length: int = 2) -> str: + """ + Clean and normalize text with configurable options. 
+ + Args: + text: Input text to clean + lang: Language code of the text + remove_stops: Whether to remove stop words + remove_numbers: Whether to remove numbers + remove_urls: Whether to remove URLs + remove_emails: Whether to remove email addresses + remove_mentions: Whether to remove social media mentions + remove_hashtags: Whether to remove hashtags + expand_contractions: Whether to expand contractions (English only) + remove_accents: Whether to remove accents from characters + min_word_length: Minimum length of words to keep + + Returns: + Cleaned text string + """ + try: + # Convert to string and lowercase + text = str(text).lower().strip() + + # Remove HTML tags if any HTML-like content is detected + if '<' in text and '>' in text: + text = self.remove_html(text) + + # Remove URLs if requested + if remove_urls: + text = re.sub(r'http\S+|www\S+', '', text) + + # Remove email addresses if requested + if remove_emails: + text = re.sub(r'\S+@\S+', '', text) + + # Remove mentions if requested + if remove_mentions: + text = re.sub(r'@\w+', '', text) + + # Remove hashtags if requested + if remove_hashtags: + text = re.sub(r'#\w+', '', text) + + # Remove numbers if requested + if remove_numbers: + text = re.sub(r'\d+', '', text) + + # Expand contractions for English text + if lang == 'en' and expand_contractions: + text = self.expand_contractions(text) + + # Remove accents if requested + if remove_accents: + text = self.remove_accents(text) + + # Language-specific character cleaning + if lang == 'tr': + text = re.sub(r'[^a-zA-ZçğıöşüÇĞİÖŞÜ\s]', '', text) + elif lang == 'ru': + text = re.sub(r'[^а-яА-Я\s]', '', text) + else: + text = re.sub(r'[^\w\s]', '', text) + + # Simple word splitting as fallback if tokenization fails + try: + words = word_tokenize(text) + except Exception as e: + self.logger.debug(f"Word tokenization failed, falling back to simple split: {str(e)}") + words = text.split() + + # Remove stop words if requested + if remove_stops and lang in self.stop_words: + words = [w for w in words if w not in self.stop_words[lang]] + + # Remove short words + words = [w for w in words if len(w) > min_word_length] + + # Rejoin words + return ' '.join(words) + + except Exception as e: + self.logger.warning(f"Error in text cleaning: {str(e)}") + return text + + def stem_text(self, text: str, lang: str = 'en') -> str: + """ + Apply language-specific stemming to text. + + Args: + text: Input text to stem + lang: Language code of the text + + Returns: + Stemmed text string + """ + try: + if lang not in self.stemmers: + return text + + words = text.split() + stemmed_words = [self.stemmers[lang].stem(word) for word in words] + return ' '.join(stemmed_words) + + except Exception as e: + self.logger.warning(f"Error in text stemming: {str(e)}") + return text + + def preprocess_text(self, text: str, lang: str = 'en', + clean_options: Dict = None, + do_stemming: bool = True) -> str: + """ + Complete preprocessing pipeline combining cleaning and stemming. 
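+
+        For example (illustrative), "These are unbelievable comments!!!" with
+        lang='en' becomes roughly "unbeliev comment" after cleaning, stop word
+        removal and stemming.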
+ + Args: + text: Input text to preprocess + lang: Language code of the text + clean_options: Dictionary of options to pass to clean_text + do_stemming: Whether to apply stemming + + Returns: + Preprocessed text string + """ + # Use default cleaning options if none provided + clean_options = clean_options or {} + + # Clean text + cleaned_text = self.clean_text(text, lang, **clean_options) + + # Apply stemming if requested + if do_stemming: + cleaned_text = self.stem_text(cleaned_text, lang) + + return cleaned_text.strip() + +# Usage example +if __name__ == "__main__": + # Initialize preprocessor + preprocessor = TextPreprocessor() + + # Example texts in different languages + examples = { + 'en': "Here's an example! This is a test text with @mentions and #hashtags http://example.com", + 'es': "¡Hola! Este es un ejemplo de texto en español con números 12345", + 'fr': "Voici un exemple de texte en français avec des accents é è à", + 'tr': "Bu bir Türkçe örnek metindir ve bazı özel karakterler içerir." + } + + # Process each example + for lang, text in examples.items(): + print(f"\nProcessing {lang} text:") + print("Original:", text) + processed = preprocessor.preprocess_text(text, lang) + print("Processed:", processed) \ No newline at end of file