---
datasets:
- seanius/toxic-or-neutral-text-labelled
language:
- en
library_name: transformers
base_model: distilbert/distilbert-base-uncased
---

An ONNX model: a fine-tuned version of DistilBERT that classifies text as one of:

- neutral
- offensive_language
- harmful_behaviour
- hate_speech

The model was trained using the [csfy tool](https://github.com/mrseanryan/csfy) and the dataset [seanius/toxic-or-neutral-text-labelled](https://huggingface.co/datasets/seanius/toxic-or-neutral-text-labelled).

The base model (distilbert-base-uncased) is also required: the usage example loads its tokenizer to prepare the input text.

For an example of how to run the model, see the Usage section below, or see the [csfy tool](https://github.com/mrseanryan/csfy).

The prediction is a number indicating the class, which is decoded to a label name via the `label_mapping.json` file.

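As a minimal sketch of that decoding step, assuming `label_mapping.json` holds a `labels` list indexed by class number: the JSON content and the label order below are hypothetical stand-ins for illustration, and `decode_class_index` is a helper name introduced here, so check the actual file shipped with the model.

```python
import json

# Hypothetical stand-in for the contents of label_mapping.json.
# The real file may list the labels in a different order.
EXAMPLE_MAPPING_JSON = (
    '{"labels": ["neutral", "offensive_language", '
    '"harmful_behaviour", "hate_speech"]}'
)


def decode_class_index(class_index, labels):
    """Map a predicted class number to its label name."""
    if not 0 <= class_index < len(labels):
        raise ValueError(f"class index {class_index} is out of range")
    return labels[class_index]


labels = json.loads(EXAMPLE_MAPPING_JSON)["labels"]
print(decode_class_index(0, labels))
```

With the real file, `labels` would instead come from reading `label_mapping.json` from disk.
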
# Usage

```python
import json

import numpy as np
import onnxruntime as ort
from transformers import DistilBertTokenizer


# Loading the label mappings
def load_label_mappings():
    with open("./label_mapping.json", encoding="utf-8") as f:
        data = json.load(f)
    return data['labels']


label_mappings = load_label_mappings()

# Loading the model
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
ort_session = ort.InferenceSession("./toxic-or-neutral-text-labelled.onnx")


# Predicting label for given text
def predict_via_onnx(text, ort_session, tokenizer, label_mappings):
    model_input = ort_session.get_inputs()[0]
    model_expected_input_shape = model_input.shape
    print("Model expects input shape:", model_expected_input_shape)

    # Pad or truncate the text to the length taken from the model's input shape
    inputs = tokenizer(text, return_tensors="np", padding="max_length", truncation=True, max_length=model_expected_input_shape[1])
    print("input shape", inputs['input_ids'].shape)

    input_ids = inputs['input_ids']
    if input_ids.ndim == 1:
        # Add a batch dimension
        input_ids = input_ids[np.newaxis, :]

    # Feed int64 token ids, keyed by the model's own input name
    ort_inputs = {model_input.name: input_ids.astype(np.int64)}

    # Pick the highest-scoring class from the model's first output
    ort_outputs = ort_session.run(None, ort_inputs)
    predictions = np.argmax(ort_outputs[0], axis=-1)

    predicted_label = label_mappings[predictions.item()]
    return predicted_label


predicted_label = predict_via_onnx("How do I get to the beach?", ort_session, tokenizer, label_mappings)
print(predicted_label)
```

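The function above returns only the top label. If a score per class is also useful, one option is to apply a softmax to the model output. The following is a minimal sketch in plain Python, assuming the model's first output contains raw logits with one value per class (this card does not confirm that). The logits, the label order, and the helper names `softmax` and `scores_by_label` are illustrative, not part of the model or the csfy tool.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    highest = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(value - highest) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def scores_by_label(logits, labels):
    """Pair each label name with its softmax probability."""
    return dict(zip(labels, softmax(logits)))


# Made-up logits for a single input, in a hypothetical label order.
example_labels = ["neutral", "offensive_language", "harmful_behaviour", "hate_speech"]
example_logits = [2.0, 0.5, -1.0, -0.5]

scores = scores_by_label(example_logits, example_labels)
best_label = max(scores, key=scores.get)
print(best_label, scores)
```

Under the same assumption, the logits for a single input could be taken from the usage example as `ort_outputs[0][0].tolist()`.
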
---
license: mit
---