import torch
from model.language_aware_transformer import LanguageAwareTransformer
from transformers import XLMRobertaTokenizer
import os
import re
import json
from pathlib import Path
import logging
from langdetect import detect, DetectorFactory
from langdetect.lang_detect_exception import LangDetectException
import sys
import io

# Reconfigure the standard streams for UTF-8 so multilingual text
# round-trips correctly regardless of the console's default encoding.
if sys.platform == 'win32':
    sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
    # Switch the Windows console code page to UTF-8 (65001).
    os.system('chcp 65001')
else:
    if (sys.stdout.encoding or '').lower() != 'utf-8':
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    if (sys.stdin.encoding or '').lower() != 'utf-8':
        sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')

# Log to stdout so messages interleave cleanly with the interactive prompts.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Make langdetect deterministic across runs.
DetectorFactory.seed = 0

# Languages the model supports, mapped to their integer language IDs.
SUPPORTED_LANGUAGES = {
    'en': 0, 'ru': 1, 'tr': 2, 'es': 3,
    'fr': 4, 'it': 5, 'pt': 6
}

# Per-category decision thresholds, tuned to reduce overflagging.
DEFAULT_THRESHOLDS = {
    'toxic': 0.80,
    'severe_toxic': 0.45,
    'obscene': 0.48,
    'threat': 0.42,
    'insult': 0.70,
    'identity_hate': 0.43
}
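
# For instance, a 'toxic' probability of 0.78 stays below the 0.80 threshold
# above and is not flagged, while 0.82 would be.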

# Language-specific Unicode code point ranges used as a detection fallback.
# The Latin-script entries list uppercase code points only; see
# analyze_unicode_ranges(), which also tests the uppercased character.
UNICODE_RANGES = {
    'ru': [
        (0x0400, 0x04FF),  # Cyrillic
        (0x0500, 0x052F),  # Cyrillic Supplement
    ],
    'tr': [
        (0x011E, 0x011F),  # Ğ, ğ
        (0x0130, 0x0131),  # İ, ı
        (0x015E, 0x015F),  # Ş, ş
    ],
    'es': [
        (0x00C1, 0x00C1),  # Á
        (0x00C9, 0x00C9),  # É
        (0x00CD, 0x00CD),  # Í
        (0x00D1, 0x00D1),  # Ñ
        (0x00D3, 0x00D3),  # Ó
        (0x00DA, 0x00DA),  # Ú
        (0x00DC, 0x00DC),  # Ü
    ],
    'fr': [
        (0x00C0, 0x00C6),  # À to Æ
        (0x00C8, 0x00CB),  # È to Ë
        (0x00CC, 0x00CF),  # Ì to Ï
        (0x00D2, 0x00D6),  # Ò to Ö
        (0x0152, 0x0153),  # Œ, œ
    ],
    'it': [
        (0x00C0, 0x00C0),  # À
        (0x00C8, 0x00C8),  # È
        (0x00C9, 0x00C9),  # É
        (0x00CC, 0x00CC),  # Ì
        (0x00D2, 0x00D2),  # Ò
        (0x00D9, 0x00D9),  # Ù
    ],
    'pt': [
        (0x00C0, 0x00C3),  # À to Ã
        (0x00C7, 0x00C7),  # Ç
        (0x00C9, 0x00CA),  # É, Ê
        (0x00D3, 0x00D5),  # Ó to Õ
    ]
}
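
# Example: 'ñ' (U+00F1) is not listed above, but its uppercase form 'Ñ'
# (U+00D1) falls in the 'es' entry; analyze_unicode_ranges() relies on this.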

def load_model(model_path):
    """Load the trained model and tokenizer."""
    try:
        model_dir = Path(model_path).absolute()

        if model_dir.is_dir():
            # Prefer an explicit 'latest' symlink when one exists.
            latest_link = model_dir / 'latest'
            if latest_link.exists() and latest_link.is_symlink():
                target = latest_link.readlink()  # Path.readlink() needs Python 3.9+
                if target.is_absolute():
                    model_dir = target
                else:
                    model_dir = (latest_link.parent / target).resolve()
                logger.info(f"Using latest checkpoint: {model_dir}")
            else:
                # Otherwise pick the checkpoint with the highest epoch number.
                # Sort numerically so checkpoint_epoch10 ranks above
                # checkpoint_epoch2 (a lexicographic sort would not).
                def epoch_key(path):
                    match = re.search(r'\d+', path.name)
                    return int(match.group()) if match else -1

                checkpoints = sorted(
                    [d for d in model_dir.iterdir()
                     if d.is_dir() and d.name.startswith('checkpoint_epoch')],
                    key=epoch_key
                )
                if checkpoints:
                    model_dir = checkpoints[-1]
                    logger.info(f"Using most recent checkpoint: {model_dir}")
                else:
                    logger.info("No checkpoints found, using base directory")

        logger.info(f"Loading model from: {model_dir}")

        if not model_dir.exists():
            raise FileNotFoundError(f"Model directory not found: {model_dir}")

        model = LanguageAwareTransformer(
            num_labels=6,
            hidden_size=1024,
            num_attention_heads=16,
            model_name='xlm-roberta-large'
        )

        weights_path = model_dir / 'pytorch_model.bin'
        if not weights_path.exists():
            raise FileNotFoundError(f"Model weights not found at {weights_path}")

        # Load onto CPU first so checkpoints saved on GPU also load on
        # CPU-only hosts; the model is moved to the target device below.
        state_dict = torch.load(weights_path, map_location='cpu')
        model.load_state_dict(state_dict)
        logger.info("Model weights loaded successfully")

        logger.info("Loading XLM-RoBERTa tokenizer...")
        tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')

        metadata_path = model_dir / 'metadata.json'
        if metadata_path.exists():
            with open(metadata_path) as f:
                metadata = json.load(f)
            logger.info(f"Loaded checkpoint metadata: Epoch {metadata.get('epoch', 'unknown')}")

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        model.eval()

        return model, tokenizer, device

    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        logger.error("\nPlease ensure that:")
        logger.error("1. You have trained the model first using train.py")
        logger.error("2. The model weights are saved in the correct location")
        logger.error("3. You have sufficient permissions to access the model files")
        return None, None, None
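
# Usage sketch (path taken from main() below):
#   model, tokenizer, device = load_model('weights/toxic_classifier_xlm-roberta-large/latest')
# Returns (None, None, None) on failure, which callers must check.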

def adjust_thresholds(thresholds):
    """
    Raise per-language thresholds to at least the recommended defaults,
    reducing overflagging from thresholds that were tuned too low.
    """
    if not thresholds:
        return thresholds

    # Copy the nested dicts so the caller's mapping is not mutated.
    adjusted = {lang_id: dict(categories) for lang_id, categories in thresholds.items()}

    for lang_id in adjusted:
        for category, recommended in DEFAULT_THRESHOLDS.items():
            if category in adjusted[lang_id]:
                # Never lower an existing threshold; only raise it.
                adjusted[lang_id][category] = max(adjusted[lang_id][category], recommended)

    return adjusted
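
# A minimal illustration, assuming thresholds maps language IDs to
# per-category floats: adjust_thresholds({0: {'toxic': 0.5}}) yields
# {0: {'toxic': 0.8}}, since the recommended 'toxic' threshold is higher.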

def analyze_unicode_ranges(text):
    """Analyze text for characters in language-specific Unicode ranges."""
    scores = {lang: 0 for lang in SUPPORTED_LANGUAGES}

    for char in text:
        # The Latin-script ranges list uppercase code points, so also test
        # the uppercased character (when it maps to a single code point).
        upper = char.upper()
        codes = {ord(char)} | ({ord(upper)} if len(upper) == 1 else set())
        for lang, ranges in UNICODE_RANGES.items():
            for start, end in ranges:
                if any(start <= code <= end for code in codes):
                    scores[lang] += 1
                    break

    return scores

def analyze_tokenizer_stats(text, tokenizer):
    """Analyze tokenizer statistics for language detection."""
    tokens = tokenizer.tokenize(text)

    scores = {lang: 0 for lang in SUPPORTED_LANGUAGES}

    for token in tokens:
        token = token.lower()
        # Crude substring heuristic: credit the first supported language
        # code that appears inside the token.
        for lang in SUPPORTED_LANGUAGES:
            if lang in token:
                scores[lang] += 1
                break

    return scores
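
# Example of the first-match-wins behavior: the token '▁essen' contains
# both 'en' and 'es', but 'en' is credited because it is checked first.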

def detect_language(text, tokenizer):
    """
    Enhanced language detection using langdetect with multiple fallback methods:
    1. Primary: langdetect library
    2. Fallback 1: ASCII analysis for English
    3. Fallback 2: Unicode range analysis
    4. Fallback 3: Tokenizer statistics
    """
    try:
        text = text.strip()

        # Empty or non-word input: default to English.
        if not text or not re.search(r'\w', text):
            return SUPPORTED_LANGUAGES['en']

        # Primary: langdetect.
        try:
            detected_code = detect(text)

            # langdetect normally returns two-letter ISO 639-1 codes; map
            # three-letter ISO 639-2 codes defensively, just in case.
            lang_mapping = {
                'eng': 'en',
                'rus': 'ru',
                'tur': 'tr',
                'spa': 'es',
                'fra': 'fr',
                'ita': 'it',
                'por': 'pt'
            }
            detected_code = lang_mapping.get(detected_code, detected_code)

            if detected_code in SUPPORTED_LANGUAGES:
                return SUPPORTED_LANGUAGES[detected_code]
        except LangDetectException:
            pass

        # Fallback 1: pure-ASCII text is assumed to be English.
        if all(ord(c) < 128 for c in text):
            return SUPPORTED_LANGUAGES['en']

        # Fallbacks 2 and 3: combine Unicode-range and tokenizer scores,
        # weighting the Unicode evidence more heavily.
        unicode_scores = analyze_unicode_ranges(text)
        tokenizer_scores = analyze_tokenizer_stats(text, tokenizer)

        final_scores = {
            lang: unicode_scores[lang] * 2 + tokenizer_scores[lang]
            for lang in SUPPORTED_LANGUAGES
        }

        if any(score > 0 for score in final_scores.values()):
            detected_lang = max(final_scores.items(), key=lambda x: x[1])[0]
            return SUPPORTED_LANGUAGES[detected_lang]

        # Last resort: default to English.
        return SUPPORTED_LANGUAGES['en']

    except Exception as e:
        logger.warning(f"Language detection failed ({str(e)}). Using English.")
        return SUPPORTED_LANGUAGES['en']
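
# For example, detect_language('Bonjour tout le monde', tokenizer) should
# return 4, the 'fr' entry in SUPPORTED_LANGUAGES, typically via langdetect.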

def predict_toxicity(text, model, tokenizer, device):
    """Predict toxicity labels for a given text."""
    lang_id = detect_language(text, tokenizer)

    encoding = tokenizer(
        text,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = outputs['probabilities']

    probabilities = predictions[0].cpu().numpy()

    # Label order must match the model's output heads.
    labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

    results = {}
    for label, prob in zip(labels, probabilities):
        threshold = DEFAULT_THRESHOLDS.get(label, 0.5)
        results[label] = {
            'probability': float(prob),
            'is_toxic': prob > threshold,
            'threshold': threshold
        }

    return results, lang_id
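
# Sketch of the returned structure (probability value is illustrative):
#   results, lang_id = predict_toxicity("some text", model, tokenizer, device)
#   results['insult'] -> {'probability': 0.12, 'is_toxic': False, 'threshold': 0.7}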

def main():
    print("Loading model...")
    model_path = 'weights/toxic_classifier_xlm-roberta-large/latest'
    model, tokenizer, device = load_model(model_path)

    if model is None or tokenizer is None:
        return

    while True:
        try:
            print("\nEnter text to analyze (or 'q' to quit):")
            try:
                if sys.platform == 'win32':
                    # Read raw bytes on Windows and decode explicitly.
                    text = sys.stdin.buffer.readline().decode('utf-8').strip()
                else:
                    text = input().strip()
            except UnicodeDecodeError:
                # Fall back to Latin-1, which accepts any byte sequence.
                text = sys.stdin.buffer.readline().decode('latin-1').strip()

            if text.lower() == 'q':
                break

            if not text:
                print("Please enter some text to analyze.")
                continue

            print("\nAnalyzing text...")
            predictions, lang_id = predict_toxicity(text, model, tokenizer, device)

            lang_name = next(k for k, v in SUPPORTED_LANGUAGES.items() if v == lang_id)

            print("\nResults:")
            print("-" * 50)
            print(f"Text: {text}")
            print(f"Detected Language: {lang_name}")
            print("\nToxicity Analysis:")

            any_toxic = False
            for label, result in predictions.items():
                if result['is_toxic']:
                    any_toxic = True
                    print(f"- {label}: {result['probability']:.2%} (threshold: {result['threshold']:.2%}) ⚠️")

            print("\nOther categories:")
            for label, result in predictions.items():
                if not result['is_toxic']:
                    print(f"- {label}: {result['probability']:.2%} (threshold: {result['threshold']:.2%}) ✓")

            print("\nOverall Assessment:")
            if any_toxic:
                print("⚠️ This text contains toxic content")
            else:
                print("✅ This text appears to be non-toxic")

        except Exception as e:
            logger.error(f"Unexpected error: {str(e)}")
            print("\nAn unexpected error occurred. Please try again.")
            continue


if __name__ == "__main__":
    main()