import io
import json
import logging
import os
import re
import sys
from pathlib import Path

import torch
from langdetect import detect, DetectorFactory
from langdetect.lang_detect_exception import LangDetectException
from transformers import XLMRobertaTokenizer

from model.language_aware_transformer import LanguageAwareTransformer

# Force UTF-8 encoding for stdin/stdout
if sys.platform == 'win32':
    # Windows-specific handling: rewrap the standard streams as UTF-8
    sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
    # Switch the console code page to UTF-8
    os.system('chcp 65001')
else:
    # Unix-like systems
    if sys.stdout.encoding != 'utf-8':
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    if sys.stdin.encoding != 'utf-8':
        sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')

# Set up logging with UTF-8 encoding
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Ensure reproducibility with langdetect
DetectorFactory.seed = 0

SUPPORTED_LANGUAGES = {
    'en': 0, 'ru': 1, 'tr': 2, 'es': 3,
    'fr': 4, 'it': 5, 'pt': 6
}

# Default thresholds optimized on validation set
DEFAULT_THRESHOLDS = {
    'toxic': 0.80,          # Optimized for general toxicity
    'severe_toxic': 0.45,   # Lower to catch serious cases
    'obscene': 0.48,        # Balanced for precision/recall
    'threat': 0.42,         # Lower to catch potential threats
    'insult': 0.70,         # Balanced for common cases
    'identity_hate': 0.43   # Lower to catch hate speech
}
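
# Worked example: with these defaults a 'threat' probability of 0.50 is
# flagged (0.50 > 0.42), while a 'toxic' probability of 0.50 is not (0.50 < 0.80).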

# Unicode ranges for different scripts
UNICODE_RANGES = {
    'ru': [
        (0x0400, 0x04FF),  # Cyrillic
        (0x0500, 0x052F),  # Cyrillic Supplement
    ],
    'tr': [
        (0x011E, 0x011F),  # Ğ ğ
        (0x0130, 0x0131),  # İ ı
        (0x015E, 0x015F),  # Ş ş
    ],
    'es': [
        (0x00C1, 0x00C1),  # Á
        (0x00C9, 0x00C9),  # É
        (0x00CD, 0x00CD),  # Í
        (0x00D1, 0x00D1),  # Ñ
        (0x00D3, 0x00D3),  # Ó
        (0x00DA, 0x00DA),  # Ú
        (0x00DC, 0x00DC),  # Ü
    ],
    'fr': [
        (0x00C0, 0x00C6),  # À-Æ
        (0x00C8, 0x00CB),  # È-Ë
        (0x00CC, 0x00CF),  # Ì-Ï
        (0x00D2, 0x00D6),  # Ò-Ö
        (0x0152, 0x0153),  # Œ œ
    ],
    'it': [
        (0x00C0, 0x00C0),  # À
        (0x00C8, 0x00C8),  # È
        (0x00C9, 0x00C9),  # É
        (0x00CC, 0x00CC),  # Ì
        (0x00D2, 0x00D2),  # Ò
        (0x00D9, 0x00D9),  # Ù
    ],
    'pt': [
        (0x00C0, 0x00C3),  # À-Ã
        (0x00C7, 0x00C7),  # Ç
        (0x00C9, 0x00CA),  # É-Ê
        (0x00D3, 0x00D5),  # Ó-Õ
    ]
}

def load_model(model_path):
    """Load the trained model and tokenizer"""
    try:
        # Convert to absolute Path object
        model_dir = Path(model_path).absolute()
        
        if model_dir.is_dir():
            # Check for 'latest' symlink first
            latest_link = model_dir / 'latest'
            if latest_link.exists() and latest_link.is_symlink():
                # Get the symlink target (Path.readlink requires Python 3.9+)
                target = latest_link.readlink()
                # If target is absolute, use it directly
                if target.is_absolute():
                    model_dir = target
                else:
                    # If target is relative, resolve it relative to the symlink's directory
                    model_dir = (latest_link.parent / target).resolve()
                logger.info(f"Using latest checkpoint: {model_dir}")
            else:
                # Find the most recent checkpoint, sorting numerically by epoch
                # (lexicographic order would put epoch 10 before epoch 2)
                checkpoints = sorted(
                    (d for d in model_dir.iterdir()
                     if d.is_dir() and d.name.startswith('checkpoint_epoch')),
                    key=lambda d: int(m.group()) if (m := re.search(r'\d+', d.name)) else -1
                )
                if checkpoints:
                    model_dir = checkpoints[-1]
                    logger.info(f"Using most recent checkpoint: {model_dir}")
                else:
                    logger.info("No checkpoints found, using base directory")
        
        logger.info(f"Loading model from: {model_dir}")
        
        # Verify the directory exists
        if not model_dir.exists():
            raise FileNotFoundError(f"Model directory not found: {model_dir}")
        
        # Initialize the custom model architecture
        model = LanguageAwareTransformer(
            num_labels=6,
            hidden_size=1024,
            num_attention_heads=16,
            model_name='xlm-roberta-large'
        )
        
        # Load the trained weights
        weights_path = model_dir / 'pytorch_model.bin'
        if not weights_path.exists():
            raise FileNotFoundError(f"Model weights not found at {weights_path}")
            
        # map_location lets checkpoints saved on GPU load on CPU-only machines
        state_dict = torch.load(weights_path, map_location='cpu')
        model.load_state_dict(state_dict)
        logger.info("Model weights loaded successfully")
        
        # Load base XLM-RoBERTa tokenizer directly
        logger.info("Loading XLM-RoBERTa tokenizer...")
        tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
        
        # Load training metadata if available
        metadata_path = model_dir / 'metadata.json'
        if metadata_path.exists():
            with open(metadata_path) as f:
                metadata = json.load(f)
            logger.info(f"Loaded checkpoint metadata: Epoch {metadata.get('epoch', 'unknown')}")
        
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        model.eval()
        
        return model, tokenizer, device
        
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        logger.error("\nPlease ensure that:")
        logger.error("1. You have trained the model first using train.py")
        logger.error("2. The model weights are saved in the correct location")
        logger.error("3. You have sufficient permissions to access the model files")
        return None, None, None

def adjust_thresholds(thresholds):
    """
    Adjust thresholds based on recommendations to reduce overflagging
    """
    if not thresholds:
        return thresholds
        
    adjusted = thresholds.copy()
    # Adjust thresholds for each language
    for lang_id in adjusted:
        for category, recommended in DEFAULT_THRESHOLDS.items():
            if category in adjusted[lang_id]:
                # Only increase threshold if recommended is higher
                adjusted[lang_id][category] = max(adjusted[lang_id][category], recommended)
    
    return adjusted
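
# Example usage (illustrative; this helper is not called by this script):
#   per_lang = {0: {'toxic': 0.60}, 1: {'toxic': 0.90}}
#   adjust_thresholds(per_lang)  # -> {0: {'toxic': 0.80}, 1: {'toxic': 0.90}}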

def analyze_unicode_ranges(text):
    """Analyze text for characters in language-specific Unicode ranges"""
    scores = {lang: 0 for lang in SUPPORTED_LANGUAGES.keys()}
    
    for char in text:
        code = ord(char)
        for lang, ranges in UNICODE_RANGES.items():
            for start, end in ranges:
                if start <= code <= end:
                    scores[lang] += 1
    
    return scores
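
# Illustrative example: every character of the Russian word "Привет" falls in
# the Cyrillic block, so analyze_unicode_ranges("Привет") returns a score of 6
# for 'ru' and 0 for every other language.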

def analyze_tokenizer_stats(text, tokenizer):
    """Analyze tokenizer statistics for language detection.

    NOTE: this is a crude last-resort heuristic -- it counts tokens whose
    surface form happens to contain a language code (so 'en' also matches
    tokens like 'enter'). It is only consulted, together with the Unicode
    range analysis, after langdetect fails to return a supported language.
    """
    tokens = tokenizer.tokenize(text)

    # Count tokens containing each language code; the first matching code
    # wins, and SUPPORTED_LANGUAGES preserves the en..pt checking order
    scores = {lang: 0 for lang in SUPPORTED_LANGUAGES}
    for token in tokens:
        token = token.lower()
        for lang in SUPPORTED_LANGUAGES:
            if lang in token:
                scores[lang] += 1
                break

    return scores

def detect_language(text, tokenizer):
    """
    Enhanced language detection using langdetect with multiple fallback methods:
    1. Primary: langdetect library
    2. Fallback 1: ASCII analysis for English
    3. Fallback 2: Unicode range analysis
    4. Fallback 3: Tokenizer statistics
    """
    try:
        # Clean text
        text = text.strip()
        
        # If empty or just punctuation, default to English
        if not text or not re.search(r'\w', text):
            return SUPPORTED_LANGUAGES['en']
            
        # Primary method: Use langdetect
        try:
            detected_code = detect(text)
            # langdetect returns ISO 639-1 codes; map three-letter variants defensively
            lang_mapping = {
                'eng': 'en',
                'rus': 'ru',
                'tur': 'tr',
                'spa': 'es',
                'fra': 'fr',
                'ita': 'it',
                'por': 'pt'
            }
            detected_code = lang_mapping.get(detected_code, detected_code)
            
            if detected_code in SUPPORTED_LANGUAGES:
                return SUPPORTED_LANGUAGES[detected_code]
        except LangDetectException:
            pass  # Continue to fallback methods
            
        # Fallback 1: If text is ASCII only, likely English
        if all(ord(c) < 128 for c in text):
            return SUPPORTED_LANGUAGES['en']
        
        # Fallback 2 & 3: Combine Unicode analysis and tokenizer statistics
        unicode_scores = analyze_unicode_ranges(text)
        tokenizer_scores = analyze_tokenizer_stats(text, tokenizer)
        
        # Combine scores, weighting Unicode-range evidence twice as heavily
        final_scores = {
            lang: unicode_scores[lang] * 2 + tokenizer_scores[lang]
            for lang in SUPPORTED_LANGUAGES
        }
        
        # Get language with highest score
        if any(score > 0 for score in final_scores.values()):
            detected_lang = max(final_scores.items(), key=lambda x: x[1])[0]
            return SUPPORTED_LANGUAGES[detected_lang]
        
        # Default to English if no clear match
        return SUPPORTED_LANGUAGES['en']
        
    except Exception as e:
        logger.warning(f"Language detection failed ({str(e)}). Using English.")
        return SUPPORTED_LANGUAGES['en']
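
# Illustrative behaviour: a Turkish phrase like "Merhaba dünya" would typically
# be resolved by langdetect directly, while a pure-ASCII input such as
# "hello world" is caught by the ASCII fallback whenever langdetect raises.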

def predict_toxicity(text, model, tokenizer, device):
    """Predict toxicity labels for a given text"""
    # Detect the language (reported to the caller; the thresholds below are language-agnostic)
    lang_id = detect_language(text, tokenizer)
    
    # Tokenize text
    encoding = tokenizer(
        text,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    
    # Move to device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    # Get predictions
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = outputs['probabilities']
    
    # The model already returns probabilities; move the single row to CPU/NumPy
    probabilities = predictions[0].cpu().numpy()
    
    # Labels for toxicity types
    labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    
    # Create results dictionary using optimized thresholds
    results = {}
    for label, prob in zip(labels, probabilities):
        threshold = DEFAULT_THRESHOLDS.get(label, 0.5)  # Use optimized defaults
        results[label] = {
            'probability': float(prob),
            'is_toxic': prob > threshold,
            'threshold': threshold
        }
    
    return results, lang_id
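
# A minimal batched variant (sketch): tokenizes several texts in one call and
# reuses the same thresholding as predict_toxicity. It assumes the model
# accepts batched input_ids/attention_mask and returns a (batch, 6)
# 'probabilities' tensor, exactly like the single-text path above.
def predict_toxicity_batch(texts, model, tokenizer, device):
    """Predict toxicity labels for a list of texts in a single forward pass."""
    encoding = tokenizer(
        texts,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        probabilities = outputs['probabilities'].cpu().numpy()

    labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    batch_results = []
    for row in probabilities:
        batch_results.append({
            label: {
                'probability': float(prob),
                'is_toxic': prob > DEFAULT_THRESHOLDS.get(label, 0.5),
                'threshold': DEFAULT_THRESHOLDS.get(label, 0.5)
            }
            for label, prob in zip(labels, row)
        })
    return batch_results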

def main():
    # Load model
    print("Loading model...")
    model_path = 'weights/toxic_classifier_xlm-roberta-large/latest'
    model, tokenizer, device = load_model(model_path)
    
    if model is None or tokenizer is None:
        return
    
    while True:
        try:
            # Get input text with proper Unicode handling
            print("\nEnter text to analyze (or 'q' to quit):")
            try:
                if sys.platform == 'win32':
                    # Windows-specific input handling
                    text = sys.stdin.buffer.readline().decode('utf-8').strip()
                else:
                    text = input().strip()
            except UnicodeDecodeError:
                # Fall back to latin-1 if UTF-8 fails (same handling on every platform)
                text = sys.stdin.buffer.readline().decode('latin-1').strip()
            
            if text.lower() == 'q':
                break
            
            if not text:
                print("Please enter some text to analyze.")
                continue
            
            # Make prediction
            print("\nAnalyzing text...")
            predictions, lang_id = predict_toxicity(text, model, tokenizer, device)
            
            # Map the language id back to its language code
            lang_name = next(k for k, v in SUPPORTED_LANGUAGES.items() if v == lang_id)
            
            # Print results
            print("\nResults:")
            print("-" * 50)
            print(f"Text: {text}")
            print(f"Detected Language: {lang_name}")
            print("\nToxicity Analysis:")
            
            any_toxic = False
            for label, result in predictions.items():
                if result['is_toxic']:
                    any_toxic = True
                    print(f"- {label}: {result['probability']:.2%} (threshold: {result['threshold']:.2%}) ⚠️")
            
            # Print non-toxic results with lower emphasis
            print("\nOther categories:")
            for label, result in predictions.items():
                if not result['is_toxic']:
                    print(f"- {label}: {result['probability']:.2%} (threshold: {result['threshold']:.2%}) ✓")
            
            # Overall assessment
            print("\nOverall Assessment:")
            if any_toxic:
                print("⚠️  This text contains toxic content")
            else:
                print("✅  This text appears to be non-toxic")
                
        except Exception as e:
            logger.error(f"Unexpected error: {str(e)}")
            print("\nAn unexpected error occurred. Please try again.")
            continue

if __name__ == "__main__":
    main()