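"""Inference wrapper for the language-aware toxicity classifier.

Loads either an exported ONNX model (served through ONNX Runtime) or a
PyTorch checkpoint of LanguageAwareTransformer, and exposes a batched
`predict` method that returns per-label toxicity probabilities.
"""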
import torch
import onnxruntime as ort
from transformers import XLMRobertaTokenizer
import numpy as np
import os
class OptimizedToxicityClassifier:
"""High-performance toxicity classifier for production"""
def __init__(self, onnx_path=None, pytorch_path=None, device='cuda'):
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
# Language mapping
self.lang_map = {
'en': 0, 'ru': 1, 'tr': 2, 'es': 3,
'fr': 4, 'it': 5, 'pt': 6
}
# Label names
self.label_names = [
'toxic', 'severe_toxic', 'obscene',
'threat', 'insult', 'identity_hate'
]
# Load ONNX model if path provided
if onnx_path and os.path.exists(onnx_path):
# Use ONNX Runtime for inference
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] \
if device == 'cuda' and 'CUDAExecutionProvider' in ort.get_available_providers() \
else ['CPUExecutionProvider']
self.session = ort.InferenceSession(onnx_path, providers=providers)
self.use_onnx = True
print(f"Loaded ONNX model from {onnx_path}")
# Fall back to PyTorch if ONNX not available
elif pytorch_path:
from model.language_aware_transformer import LanguageAwareTransformer
# Handle directory structure with checkpoint folders and 'latest' symlink
if os.path.isdir(pytorch_path):
# Check if there's a 'latest' symlink
latest_path = os.path.join(pytorch_path, 'latest')
if os.path.islink(latest_path) and os.path.exists(latest_path):
checkpoint_dir = latest_path
else:
# If no 'latest' symlink, look for checkpoint dirs and use the most recent one
checkpoint_dirs = [d for d in os.listdir(pytorch_path) if d.startswith('checkpoint_epoch')]
if checkpoint_dirs:
checkpoint_dirs.sort() # Sort to get the latest by name
checkpoint_dir = os.path.join(pytorch_path, checkpoint_dirs[-1])
else:
raise ValueError(f"No checkpoint directories found in {pytorch_path}")
# Look for PyTorch model files in the checkpoint directory
model_file = None
potential_files = ['pytorch_model.bin', 'model.pt', 'model.pth']
for file in potential_files:
candidate = os.path.join(checkpoint_dir, file)
if os.path.exists(candidate):
model_file = candidate
break
if not model_file:
raise FileNotFoundError(f"No model file found in {checkpoint_dir}")
print(f"Using model from checkpoint: {checkpoint_dir}")
model_path = model_file
else:
# If pytorch_path is a direct file path
model_path = pytorch_path
self.model = LanguageAwareTransformer(num_labels=6)
self.model.load_state_dict(torch.load(model_path, map_location=device))
self.model.to(device)
self.model.eval()
self.use_onnx = False
self.device = device
print(f"Loaded PyTorch model from {model_path}")
else:
raise ValueError("Either onnx_path or pytorch_path must be provided")
    def predict(self, texts, langs=None, batch_size=8):
        """
        Predict toxicity for a list of texts.

        Args:
            texts: List of text strings
            langs: List of language codes (e.g., 'en', 'fr')
            batch_size: Batch size for processing

        Returns:
            List of dictionaries with toxicity predictions
        """
        results = []

        # Default to English if no languages are provided
        if langs is None:
            langs = ['en'] * len(texts)

        # Convert language codes to IDs (unknown codes fall back to English)
        lang_ids = [self.lang_map.get(lang, 0) for lang in langs]

        # Per-label decision thresholds, increased by ~20% from the original values
        # Order: [toxic, severe_toxic, obscene, threat, insult, identity_hate]
        lang_thresholds = {
            'default': [0.60, 0.54, 0.60, 0.48, 0.60, 0.50]
        }

        # Process in batches
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            batch_langs = lang_ids[i:i+batch_size]

            # Tokenize
            inputs = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=128,
                return_tensors='pt'
            )

            # Get predictions
            if self.use_onnx:
                # ONNX inference
                ort_inputs = {
                    'input_ids': inputs['input_ids'].numpy(),
                    'attention_mask': inputs['attention_mask'].numpy(),
                    'lang_ids': np.array(batch_langs, dtype=np.int64)
                }
                ort_outputs = self.session.run(None, ort_inputs)
                probabilities = 1 / (1 + np.exp(-ort_outputs[0]))  # sigmoid over logits
            else:
                # PyTorch inference
                with torch.no_grad():
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
                    lang_tensor = torch.tensor(batch_langs, dtype=torch.long, device=self.device)
                    outputs = self.model(
                        input_ids=inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        lang_ids=lang_tensor,
                        mode='inference'
                    )
                    probabilities = outputs['probabilities'].cpu().numpy()

            # Format results
            for text, lang, probs in zip(batch_texts, langs[i:i+batch_size], probabilities):
                # Apply per-language thresholds, falling back to the defaults
                thresholds = lang_thresholds.get(lang, lang_thresholds['default'])
                is_toxic = (probs >= np.array(thresholds)).astype(bool)

                result = {
                    'text': text,
                    'language': lang,
                    'probabilities': {
                        label: float(prob) for label, prob in zip(self.label_names, probs)
                    },
                    'is_toxic': bool(is_toxic.any()),
                    'toxic_categories': [
                        self.label_names[k] for k in range(len(is_toxic)) if is_toxic[k]
                    ]
                }
                results.append(result)

        return results
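

# --- Illustrative usage sketch ---
# The model path and sample texts below are placeholders, not part of the project;
# point onnx_path (or pytorch_path) at your actual exported model or checkpoint
# directory before running.
if __name__ == '__main__':
    classifier = OptimizedToxicityClassifier(
        onnx_path='weights/toxic_classifier.onnx',  # hypothetical path
        device='cuda'
    )
    predictions = classifier.predict(
        ["That was a genuinely helpful answer.", "You are an idiot and I will hurt you."],
        langs=['en', 'en']
    )
    for pred in predictions:
        flagged = ', '.join(pred['toxic_categories']) or 'none'
        print(f"{pred['language']}: toxic={pred['is_toxic']} categories=[{flagged}]")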