|
import streamlit as st |
|
import torch |
|
import torch.nn as nn |
|
from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM |
|
from huggingface_hub import login |
|
import os |
|
import time |
|
|
|
|
|
class TinyTransformer(nn.Module):
    """Small Transformer encoder mapping token ids to one sigmoid score per sequence.

    Pipeline: token embedding + learned positional table -> TransformerEncoder
    (batch_first) -> mean over the sequence axis -> Linear(embed_dim, 1) -> sigmoid.

    Args:
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Model width (d_model of each encoder layer).
        num_heads: Attention heads per encoder layer.
        ff_dim: Hidden width of each encoder layer's feed-forward block.
        num_layers: Number of stacked encoder layers.
        max_seq_len: Number of rows in the learned positional table, i.e. the
            longest sequence ``forward`` accepts. Defaults to 512, the size that
            was previously hard-coded, so state dicts saved with that layout
            still load.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers, max_seq_len=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional embeddings, zero-initialised, one row per position.
        self.pos_encoding = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Score a batch of token-id sequences.

        Args:
            x: Integer token ids; dim 0 is the batch and dim 1 the sequence
                (the encoder is built with ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid outputs in (0, 1).

        Raises:
            ValueError: If the sequence is longer than the positional table.
        """
        seq_len = x.size(1)
        max_len = self.pos_encoding.size(1)
        if seq_len > max_len:
            # Without this check the slice below silently truncates the table
            # and the addition fails with an opaque broadcasting error.
            raise ValueError(
                f"input sequence length {seq_len} exceeds the maximum of {max_len} positions"
            )
        x = self.embedding(x) + self.pos_encoding[:, :seq_len, :]
        x = self.transformer(x)
        # Mean pooling over the sequence; padded positions are NOT masked out.
        x = x.mean(dim=1)
        x = self.fc(x)
        return self.sigmoid(x)
|
|
|
class TinyTransformerConfig(PretrainedConfig):
    """Hugging Face configuration for the Tiny-Toxic-Detector architecture.

    Holds the hyperparameters that TinyTransformerForSequenceClassification
    reads when it builds its TinyTransformer. Any extra keyword arguments are
    handed to ``PretrainedConfig.__init__`` untouched.
    """

    model_type = "tiny_transformer"

    def __init__(
        self,
        vocab_size: int = 30522,
        embed_dim: int = 64,
        num_heads: int = 2,
        ff_dim: int = 128,
        num_layers: int = 4,
        max_position_embeddings: int = 512,
        **kwargs,
    ):
        # Let the base class consume its own options before we add ours.
        super().__init__(**kwargs)

        # Encoder dimensions.
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.num_layers = num_layers

        # NOTE(review): stored on the config, but the model wrapper in this file
        # does not pass it to TinyTransformer.
        self.max_position_embeddings = max_position_embeddings
|
|
|
class TinyTransformerForSequenceClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper so TinyTransformer loads via ``from_pretrained``.

    The forward pass returns a dict whose ``"logits"`` entry is TinyTransformer's
    sigmoid output, i.e. a probability in (0, 1) rather than a raw logit.
    """

    config_class = TinyTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        # One sigmoid output instead of per-class scores.
        self.num_labels = 1
        # NOTE(review): presumably the checkpoint's parameter names are prefixed
        # with "transformer.", so this attribute name should not change.
        self.transformer = TinyTransformer(
            vocab_size=config.vocab_size,
            embed_dim=config.embed_dim,
            num_heads=config.num_heads,
            ff_dim=config.ff_dim,
            num_layers=config.num_layers,
        )

    def forward(self, input_ids, attention_mask=None):
        """Run the wrapped TinyTransformer on ``input_ids``.

        ``attention_mask`` is accepted for API compatibility but ignored, so
        padded positions take part in TinyTransformer's mean pooling.
        """
        return {"logits": self.transformer(input_ids)}
|
|
|
|
|
@st.cache_resource
def load_models_and_tokenizers(hf_token):
    """Load every model/tokenizer pair the app compares.

    Wrapped in ``st.cache_resource`` (keyed on ``hf_token``) so Streamlit reruns
    reuse the loaded objects instead of loading them again.

    Args:
        hf_token: Hugging Face access token, or None/empty when none is configured.

    Returns:
        ``(models, tokenizers, device)``: two dicts keyed by the model name used
        throughout the app, and the torch device every model was moved to
        (always CPU here).
    """
    # login(token=None) falls back to an interactive prompt, which a Streamlit
    # server process cannot answer; only authenticate with a real token.
    if hf_token:
        login(token=hf_token)

    device = torch.device("cpu")
    models = {}
    tokenizers = {}

    # Custom architecture: loaded through the config/model classes defined in this file.
    tiny_repo = "AssistantsLab/Tiny-Toxic-Detector"
    config = TinyTransformerConfig.from_pretrained(tiny_repo, use_auth_token=hf_token)
    models["Tiny-toxic-detector"] = TinyTransformerForSequenceClassification.from_pretrained(
        tiny_repo, config=config, use_auth_token=hf_token
    ).to(device)
    tokenizers["Tiny-toxic-detector"] = AutoTokenizer.from_pretrained(tiny_repo, use_auth_token=hf_token)

    # (model repo, model class, tokenizer repo). ToxicChat-T5 is paired with
    # the "t5-large" tokenizer rather than one from its own repo.
    model_configs = [
        ("s-nlp/roberta_toxicity_classifier", AutoModelForSequenceClassification, "s-nlp/roberta_toxicity_classifier"),
        ("martin-ha/toxic-comment-model", AutoModelForSequenceClassification, "martin-ha/toxic-comment-model"),
        ("lmsys/toxicchat-t5-large-v1.0", AutoModelForSeq2SeqLM, "t5-large"),
    ]

    for model_name, model_class, tokenizer_name in model_configs:
        models[model_name] = model_class.from_pretrained(model_name, use_auth_token=hf_token).to(device)
        tokenizers[model_name] = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=hf_token)

    return models, tokenizers, device
|
|
|
|
|
def predict_toxicity(text, model, tokenizer, device, model_name):
    """Classify ``text`` with one model and time the call.

    Args:
        text: Raw user input to classify.
        model: The loaded model for ``model_name``. It must provide ``generate``
            for "lmsys/toxicchat-t5-large-v1.0"; otherwise it is called with the
            tokenizer's outputs.
        tokenizer: Tokenizer paired with ``model``.
        device: torch device the encoded inputs are moved to.
        model_name: Selects the model-specific pre/post-processing.

    Returns:
        ``(prediction, inference_time)``. ``prediction`` is ``"Toxic"`` or
        ``"Not Toxic"``. ``inference_time`` is the elapsed time in seconds and
        includes tokenization and post-processing.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump with
    # system clock adjustments and is coarse on some platforms.
    start_time = time.perf_counter()

    if model_name == "lmsys/toxicchat-t5-large-v1.0":
        # ToxicChat-T5 is prompted with a fixed prefix and answers with text.
        prefix = "ToxicChat: "
        inputs = tokenizer.encode(prefix + text, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(inputs, max_new_tokens=5)

        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower()
        print(f"Raw model output: {prediction}")
        prediction = "Toxic" if prediction in {"positive", "pos", "toxic", "yes"} else "Not Toxic"
    else:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128, padding="max_length").to(device)
        # Inputs are splatted into model(**inputs). TinyTransformerForSequenceClassification.forward
        # only takes input_ids/attention_mask, so token_type_ids must not be forwarded.
        inputs.pop("token_type_ids", None)

        with torch.no_grad():
            outputs = model(**inputs)

        if model_name == "Tiny-toxic-detector":
            # This model's "logits" entry is already a sigmoid probability.
            probability = outputs["logits"].squeeze()
            prediction = "Toxic" if probability > 0.5 else "Not Toxic"
        else:
            # NOTE(review): assumes both classifier checkpoints order labels as
            # (non-toxic, toxic); confirm against model.config.id2label.
            logits = outputs.logits.squeeze()
            prediction = "Toxic" if logits[1] > logits[0] else "Not Toxic"

    inference_time = time.perf_counter() - start_time
    return prediction, inference_time
|
|
|
def main():
    """Render the comparison page and classify the user's text on demand.

    Reads the Hugging Face token from the ``AT`` environment variable and loads
    the models through the cached loader. When "Classify" is pressed, runs every
    model on the input and shows each prediction and inference time in a
    three-column grid.
    """
    st.set_page_config(page_title="Toxicity Detector Model Comparison", layout="wide")
    st.title("Toxicity Detector Model Comparison")

    st.markdown("""
    ### How It Works
    This application compares various toxicity detection models to classify whether a given text is toxic or not. The models being compared include:

    - [**Tiny-Toxic-Detector**](https://huggingface.co/AssistantsLab/Tiny-Toxic-Detector): A 2M parameter model with a new architecture released by [AssistantsLab](https://huggingface.co/AssistantsLab).
    - [**RoBERTa-Toxicity-Classifier**](https://huggingface.co/s-nlp/roberta_toxicity_classifier): A 124M parameter RoBERTa-based model.
    - [**Toxic-Comment-Model**](https://huggingface.co/martin-ha/toxic-comment-model): A 67M parameter DistilBERT-based model.
    - [**ToxicChat-T5**](https://huggingface.co/lmsys/toxicchat-t5-large-v1.0): A 738M parameter T5-based model.

    Simply enter the text you want to classify, and the app will provide the predictions from each model, along with the inference time.
    Please note these models are (mostly) English-only.
    """)

    hf_token = os.getenv('AT')
    models, tokenizers, device = load_models_and_tokenizers(hf_token)

    # Stable sort on a boolean key: the other models keep their load order and
    # "Tiny-toxic-detector" (key True) is moved after them.
    model_names = sorted(models.keys(), key=lambda x: x == "Tiny-toxic-detector")

    text = st.text_area("Enter text to classify:", height=150)

    if st.button("Classify"):
        # Whitespace-only input is treated like empty input rather than being
        # sent to every model.
        if text.strip():
            progress_bar = st.progress(0)
            results = []

            for i, model_name in enumerate(model_names):
                with st.spinner(f"Classifying with {model_name}..."):
                    prediction, inference_time = predict_toxicity(
                        text, models[model_name], tokenizers[model_name], device, model_name
                    )
                    results.append((model_name, prediction, inference_time))
                progress_bar.progress((i + 1) / len(model_names))

            st.success("Classification complete!")
            progress_bar.empty()

            # Lay results out round-robin across three columns.
            columns = st.columns(3)
            for i, (model_name, prediction, inference_time) in enumerate(results):
                with columns[i % 3]:
                    st.subheader(model_name)
                    st.write(f"Prediction: {prediction}")
                    st.write(f"Inference Time: {inference_time:.4f}s")
                    st.write("---")
        else:
            st.warning("Please enter some text to classify.")
|
|
|
# Script entry point: render the app when this file is executed directly
# (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()