"""Gradio demo for a multilingual toxic-comment classifier.

Loads an ONNX export when available (otherwise a PyTorch checkpoint) and
serves a web UI that scores text across six toxicity categories.
"""

import os
from typing import Optional, Tuple

import gradio as gr
import langid
import matplotlib

matplotlib.use("Agg")  # headless backend: figures are rendered server-side
import matplotlib.pyplot as plt
import torch

from model.inference_optimized import OptimizedToxicityClassifier

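# Model paths can be overridden via environment variables, so the same app
# can serve a different checkpoint without code changes.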
ONNX_MODEL_PATH = os.environ.get("ONNX_MODEL_PATH", "weights/toxic_classifier.onnx")
PYTORCH_MODEL_PATH = os.environ.get(
    "PYTORCH_MODEL_PATH",
    "weights/toxic_classifier_xlm-roberta-large/pytorch_model.bin",
)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

SUPPORTED_LANGUAGES = {
    'en': 'English',
    'ru': 'Russian',
    'tr': 'Turkish',
    'es': 'Spanish',
    'fr': 'French',
    'it': 'Italian',
    'pt': 'Portuguese',
}

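# Prefer the ONNX export when it exists (typically faster to load and run);
# otherwise fall back to the PyTorch checkpoint.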
try:
    if os.path.exists(ONNX_MODEL_PATH):
        classifier = OptimizedToxicityClassifier(onnx_path=ONNX_MODEL_PATH, device=DEVICE)
        print(f"Loaded ONNX model from {ONNX_MODEL_PATH}")
    else:
        classifier = OptimizedToxicityClassifier(pytorch_path=PYTORCH_MODEL_PATH, device=DEVICE)
        print(f"Loaded PyTorch model from {PYTORCH_MODEL_PATH}")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    classifier = None


def detect_language(text: str) -> str:
    """Detect the language of the input text, defaulting to English."""
    try:
        # langid returns a (language_code, score) pair; keep only the code.
        lang, _ = langid.classify(text)
        return lang if lang in SUPPORTED_LANGUAGES else 'en'
    except Exception:
        return 'en'


def predict_toxicity(text: str, selected_language: Optional[str] = None) -> Tuple[str, Optional[plt.Figure]]:
    """Predict toxicity of the input text.

    Returns an (html_summary, matplotlib_figure) pair matching the two
    Gradio output components; the figure is None on errors.
    """
    if not text or not text.strip():
        return "<div class='error'>Please enter some text to analyze.</div>", None

    if classifier is None:
        return "<div class='error'>Model not loaded. Please check logs.</div>", None

    if not selected_language or selected_language == "Auto-detect":
        lang_code = detect_language(text)
        detected = True
    else:
        # Map the dropdown's display name back to its ISO 639-1 code.
        lang_code = next((code for code, name in SUPPORTED_LANGUAGES.items()
                          if name == selected_language), 'en')
        detected = False

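    # classifier.predict takes parallel lists of texts and language codes and
    # returns one result dict per text, with "probabilities", "is_toxic", and
    # "toxic_categories" keys (see the accesses below).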
    try:
        results = classifier.predict([text], langs=[lang_code])[0]

        # Sort categories by predicted probability, highest first.
        probs = results["probabilities"]
        sorted_categories = sorted(
            [(label, probs[label]) for label in probs],
            key=lambda x: x[1],
            reverse=True
        )

        # Horizontal bar chart: red for categories at or above the 50%
        # threshold, blue otherwise.
        fig, ax = plt.subplots(figsize=(10, 6))
        labels = [label.replace('_', ' ').title() for label, _ in sorted_categories]
        values = [prob * 100 for _, prob in sorted_categories]
        colors = ['#ff6b6b' if val >= 50 else '#74c0fc' for val in values]

        ax.barh(labels, values, color=colors)
        ax.set_xlim(0, 100)
        ax.set_xlabel('Probability (%)')
        ax.set_title('Toxicity Analysis')
        ax.grid(axis='x', linestyle='--', alpha=0.7)

        # Annotate each bar with its percentage value.
        for i, v in enumerate(values):
            ax.text(v + 1, i, f'{v:.1f}%', va='center')

        lang_display = SUPPORTED_LANGUAGES.get(lang_code, lang_code)
        overall_result = "TOXIC" if results["is_toxic"] else "NON-TOXIC"
        result_color = "#ff6b6b" if results["is_toxic"] else "#66d9e8"

        html_result = f"""
        <div style='margin-bottom: 20px;'>
            <h2>Analysis Result: <span style='color: {result_color};'>{overall_result}</span></h2>
            <h3>Language: {lang_display} {'(detected)' if detected else ''}</h3>
        </div>
        <div style='margin-bottom: 10px;'>
            <table width='100%' style='border-collapse: collapse;'>
                <tr style='background-color: #e9ecef; font-weight: bold;'>
                    <th style='padding: 8px; text-align: left; border: 1px solid #dee2e6;'>Category</th>
                    <th style='padding: 8px; text-align: right; border: 1px solid #dee2e6;'>Probability</th>
                    <th style='padding: 8px; text-align: center; border: 1px solid #dee2e6;'>Status</th>
                </tr>
        """

        # One table row per category, styled by the 0.5 decision threshold.
        for label, prob in sorted_categories:
            formatted_label = label.replace('_', ' ').title()
            status = "DETECTED" if prob >= 0.5 else "Not Detected"
            status_color = "#ff6b6b" if prob >= 0.5 else "#66d9e8"
            prob_percent = f"{prob * 100:.1f}%"

            html_result += f"""
                <tr>
                    <td style='padding: 8px; border: 1px solid #dee2e6;'>{formatted_label}</td>
                    <td style='padding: 8px; text-align: right; border: 1px solid #dee2e6;'>{prob_percent}</td>
                    <td style='padding: 8px; text-align: center; border: 1px solid #dee2e6; color: {status_color}; font-weight: bold;'>{status}</td>
                </tr>
            """

        html_result += "</table></div>"

        if results["is_toxic"]:
            toxic_categories = [cat.replace('_', ' ').title() for cat in results["toxic_categories"]]
            categories_list = ", ".join(toxic_categories)
            html_result += f"""
            <div style='margin-top: 10px;'>
                <p><strong>Detected toxic categories:</strong> {categories_list}</p>
            </div>
            """

        # Close the pyplot handle so repeated requests don't accumulate open
        # figures; the Figure object itself is still rendered by gr.Plot.
        plt.close(fig)
        return html_result, fig

    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"<div class='error'>Error processing text: {str(e)}</div>", None


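# Minimal direct-call sketch (assumes the model weights are present):
#   html, fig = predict_toxicity("You are such an idiot!", "Auto-detect")
#   html, fig = predict_toxicity("Bonjour!")  # language is auto-detected
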
def create_app():
    """Create and configure the Gradio interface."""
    language_options = ["Auto-detect"] + list(SUPPORTED_LANGUAGES.values())

    with gr.Blocks(css="""
        .error { color: #ff6b6b; font-weight: bold; padding: 10px; border: 1px solid #ff6b6b; }
        .container { margin: 0 auto; max-width: 900px; }
        .gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
        .example-text { font-style: italic; color: #666; }
    """) as app:
        gr.Markdown("""
        # Multilingual Toxic Comment Classifier

        This app analyzes text for different types of toxicity across multiple languages.
        Enter your text, select a language (or let it auto-detect), and click 'Analyze'.

        Supported languages: English, Russian, Turkish, Spanish, French, Italian, Portuguese
        """)

        with gr.Row():
            with gr.Column(scale=3):
                text_input = gr.Textbox(
                    label="Enter text to analyze",
                    placeholder="Type or paste text here...",
                    lines=5
                )
                lang_dropdown = gr.Dropdown(
                    choices=language_options,
                    value="Auto-detect",
                    label="Language"
                )
                analyze_btn = gr.Button("Analyze", variant="primary")

            with gr.Column(scale=2):
                gr.Markdown("### Example texts:")
                with gr.Accordion("English example"):
                    en_example_btn = gr.Button("Use English example")
                with gr.Accordion("Spanish example"):
                    es_example_btn = gr.Button("Use Spanish example")
                with gr.Accordion("French example"):
                    fr_example_btn = gr.Button("Use French example")

        en_example_text = "You are such an idiot, nobody likes your stupid content."
        es_example_text = "Eres un completo idiota y nadie te quiere."
        fr_example_text = "Tu es tellement stupide, personne n'aime ton contenu minable."

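        # Clicking an example button copies the matching snippet into the
        # input textbox.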
        en_example_btn.click(
            lambda: en_example_text,
            outputs=text_input
        )
        es_example_btn.click(
            lambda: es_example_text,
            outputs=text_input
        )
        fr_example_btn.click(
            lambda: fr_example_text,
            outputs=text_input
        )

        result_html = gr.HTML(label="Analysis Result")
        plot_output = gr.Plot(label="Toxicity Probabilities")

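        # Both the Analyze button and pressing Enter in the textbox run the
        # same prediction; predict_toxicity's (html, figure) return maps
        # positionally onto these two output components.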
        analyze_btn.click(
            predict_toxicity,
            inputs=[text_input, lang_dropdown],
            outputs=[result_html, plot_output]
        )
        text_input.submit(
            predict_toxicity,
            inputs=[text_input, lang_dropdown],
            outputs=[result_html, plot_output]
        )

        gr.Markdown("""
        ### About this model
        This model classifies text into six toxicity categories:
        - **Toxic**: General toxicity
        - **Severe Toxic**: Extreme toxicity
        - **Obscene**: Obscene content
        - **Threat**: Threatening content
        - **Insult**: Insulting content
        - **Identity Hate**: Identity-based hate

        Built using XLM-RoBERTa with language-aware fine-tuning.
        """)

    return app


if __name__ == "__main__":
    app = create_app()
    app.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        share=True,  # also expose a temporary public gradio.live URL
    )