import torch
import onnxruntime as ort
from transformers import XLMRobertaTokenizer
import numpy as np
import os

class OptimizedToxicityClassifier:
    """High-performance toxicity classifier for production"""
    
    def __init__(self, onnx_path=None, pytorch_path=None, device='cuda'):
        # Fall back to CPU when CUDA is requested but unavailable, mirroring
        # the provider check done for ONNX Runtime below
        if device == 'cuda' and not torch.cuda.is_available():
            device = 'cpu'
        self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
        
        # Language mapping
        self.lang_map = {
            'en': 0, 'ru': 1, 'tr': 2, 'es': 3, 
            'fr': 4, 'it': 5, 'pt': 6
        }
        
        # Label names
        self.label_names = [
            'toxic', 'severe_toxic', 'obscene', 
            'threat', 'insult', 'identity_hate'
        ]
        
        # Load ONNX model if path provided
        if onnx_path and os.path.exists(onnx_path):
            # Use ONNX Runtime, preferring the CUDA provider when available
            if device == 'cuda' and 'CUDAExecutionProvider' in ort.get_available_providers():
                providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
            else:
                providers = ['CPUExecutionProvider']
                
            self.session = ort.InferenceSession(onnx_path, providers=providers)
            self.use_onnx = True
            print(f"Loaded ONNX model from {onnx_path}")
        # Fall back to PyTorch if ONNX not available
        elif pytorch_path:
            from model.language_aware_transformer import LanguageAwareTransformer
            
            # Handle directory structure with checkpoint folders and 'latest' symlink
            if os.path.isdir(pytorch_path):
                # Check if there's a 'latest' symlink
                latest_path = os.path.join(pytorch_path, 'latest')
                if os.path.islink(latest_path) and os.path.exists(latest_path):
                    checkpoint_dir = latest_path
                else:
                    # If no 'latest' symlink, look for checkpoint dirs and use the most recent one
                    checkpoint_dirs = [d for d in os.listdir(pytorch_path) if d.startswith('checkpoint_epoch')]
                    if checkpoint_dirs:
                        # Sort numerically by epoch; plain lexicographic order would
                        # rank 'checkpoint_epoch10' before 'checkpoint_epoch2'
                        checkpoint_dirs.sort(key=lambda d: int(''.join(ch for ch in d if ch.isdigit()) or 0))
                        checkpoint_dir = os.path.join(pytorch_path, checkpoint_dirs[-1])
                    else:
                        raise ValueError(f"No checkpoint directories found in {pytorch_path}")
                
                # Look for PyTorch model files in the checkpoint directory
                model_file = None
                potential_files = ['pytorch_model.bin', 'model.pt', 'model.pth']
                for file in potential_files:
                    candidate = os.path.join(checkpoint_dir, file)
                    if os.path.exists(candidate):
                        model_file = candidate
                        break
                
                if not model_file:
                    raise FileNotFoundError(f"No model file found in {checkpoint_dir}")
                
                print(f"Using model from checkpoint: {checkpoint_dir}")
                model_path = model_file
            else:
                # If pytorch_path is a direct file path
                model_path = pytorch_path
                
            self.model = LanguageAwareTransformer(num_labels=6)
            self.model.load_state_dict(torch.load(model_path, map_location=device))
            self.model.to(device)
            self.model.eval()
            self.use_onnx = False
            self.device = device
            print(f"Loaded PyTorch model from {model_path}")
        else:
            raise ValueError("Either onnx_path or pytorch_path must be provided")
    
    def predict(self, texts, langs=None, batch_size=8):
        """
        Predict toxicity for a list of texts
        
        Args:
            texts: List of text strings
            langs: List of language codes (e.g., 'en', 'fr'); defaults to 'en' for every text when omitted
            batch_size: Batch size for processing
            
        Returns:
            List of dictionaries with toxicity predictions
        """
        results = []
        
        # Default every text to English when no language codes are provided
        if langs is None:
            langs = ['en'] * len(texts)
        
        # Convert language codes to IDs (unknown codes fall back to English)
        lang_ids = [self.lang_map.get(lang, 0) for lang in langs]
        
        # Process in batches
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            batch_langs = lang_ids[i:i+batch_size]
            
            # Tokenize
            inputs = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=128,
                return_tensors='pt'
            )
            
            # Get predictions
            if self.use_onnx:
                # ONNX inference
                ort_inputs = {
                    'input_ids': inputs['input_ids'].numpy(),
                    'attention_mask': inputs['attention_mask'].numpy(),
                    'lang_ids': np.array(batch_langs, dtype=np.int64)
                }
                ort_outputs = self.session.run(None, ort_inputs)
                probabilities = 1 / (1 + np.exp(-ort_outputs[0]))  # sigmoid
            else:
                # PyTorch inference
                with torch.no_grad():
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
                    lang_tensor = torch.tensor(batch_langs, dtype=torch.long, device=self.device)
                    outputs = self.model(
                        input_ids=inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        lang_ids=lang_tensor,
                        mode='inference'
                    )
                    probabilities = outputs['probabilities'].cpu().numpy()
            
            # Format results, applying per-label decision thresholds
            # (raised ~20% from the original tuned values); order matches
            # label_names: [toxic, severe_toxic, obscene, threat, insult, identity_hate]
            lang_thresholds = {
                'default': np.array([0.60, 0.54, 0.60, 0.48, 0.60, 0.50])
            }
            for text, lang, probs in zip(batch_texts, langs[i:i+batch_size], probabilities):
                thresholds = lang_thresholds.get(lang, lang_thresholds['default'])
                is_toxic = (probs >= thresholds).astype(bool)
                
                result = {
                    'text': text,
                    'language': lang,
                    'probabilities': {
                        label: float(prob) for label, prob in zip(self.label_names, probs)
                    },
                    'is_toxic': bool(is_toxic.any()),
                    'toxic_categories': [
                        self.label_names[k] for k in range(len(is_toxic)) if is_toxic[k]
                    ]
                }
                results.append(result)
        
        return results
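
# ---------------------------------------------------------------------------
# The two sketches below are illustrative additions, not part of the original
# module. Paths, the export helper, and the assumed forward() signature are
# placeholders; adapt them to the actual LanguageAwareTransformer before use.
# ---------------------------------------------------------------------------

def export_to_onnx(model, output_path='toxic_classifier.onnx'):
    """Hypothetical export helper: produces an ONNX file with the input names
    ('input_ids', 'attention_mask', 'lang_ids') and the raw-logit output that
    OptimizedToxicityClassifier.predict() expects. Assumes the model's
    forward() accepts these three tensors positionally and returns logits;
    if it returns a dict (as in the PyTorch branch above), wrap it first so
    the traced graph emits logits only."""
    model.eval()
    dummy_ids = torch.ones(1, 128, dtype=torch.long)
    dummy_mask = torch.ones(1, 128, dtype=torch.long)
    dummy_langs = torch.zeros(1, dtype=torch.long)
    torch.onnx.export(
        model,
        (dummy_ids, dummy_mask, dummy_langs),
        output_path,
        input_names=['input_ids', 'attention_mask', 'lang_ids'],
        output_names=['logits'],
        dynamic_axes={
            'input_ids': {0: 'batch', 1: 'sequence'},
            'attention_mask': {0: 'batch', 1: 'sequence'},
            'lang_ids': {0: 'batch'},
            'logits': {0: 'batch'},
        },
        opset_version=14,
    )


if __name__ == '__main__':
    # Minimal usage sketch; the 'weights/...' paths are placeholders for
    # wherever your exported ONNX file or training checkpoints actually live.
    classifier = OptimizedToxicityClassifier(
        onnx_path='weights/toxic_classifier.onnx',
        pytorch_path='weights/checkpoints',
        device='cuda' if torch.cuda.is_available() else 'cpu'
    )
    predictions = classifier.predict(
        ["You are wonderful!", "Je vais te détruire."],  # French example exercises the language-aware path
        langs=['en', 'fr']
    )
    for pred in predictions:
        print(f"{pred['language']}: {pred['text']!r} -> "
              f"{pred['toxic_categories'] or 'clean'}")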