Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
import torch.nn as nn | |
from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM | |
from huggingface_hub import login | |
import os | |
import time | |
# Model Architecture
class TinyTransformer(nn.Module):
    """Small Transformer encoder that maps token ids to one probability per sequence.

    Token embeddings plus a learned positional table are run through a stack of
    ``nn.TransformerEncoderLayer`` blocks (batch-first), mean-pooled over the
    sequence dimension, projected to a single value and squashed with a sigmoid.

    Args:
        vocab_size: number of token ids the embedding table covers.
        embed_dim: model width (``d_model``).
        num_heads: attention heads per encoder layer.
        ff_dim: hidden size of each layer's feed-forward network.
        num_layers: number of stacked encoder layers.
        max_len: rows in the learned positional table, i.e. the longest sequence
            ``forward`` accepts. Defaults to 512, the value previously hard-coded.
            It sets the shape of ``pos_encoding``, so a checkpoint saved with the
            default only loads into a model built with the same value.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers, max_len=512):
        super().__init__()
        # Attribute names below are state-dict keys; keep them stable so
        # existing checkpoints continue to load.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoding = nn.Parameter(torch.zeros(1, max_len, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of probabilities for token ids ``x``.

        ``x`` is indexed as ``(batch, seq_len)``. No padding mask is applied, so
        padding positions are included in the mean pooling.

        Raises:
            ValueError: if ``seq_len`` exceeds the positional table length.
        """
        seq_len = x.size(1)
        max_len = self.pos_encoding.size(1)
        if seq_len > max_len:
            # Slicing would silently stop at max_len and the addition below would
            # then fail with an opaque shape-mismatch error; report it clearly.
            raise ValueError(
                f"Input sequence length {seq_len} exceeds the maximum of "
                f"{max_len} positions supported by the positional encoding."
            )
        x = self.embedding(x) + self.pos_encoding[:, :seq_len, :]
        x = self.transformer(x)
        x = x.mean(dim=1)  # Global average pooling over the sequence dimension
        x = self.fc(x)
        return self.sigmoid(x)
class TinyTransformerConfig(PretrainedConfig):
    """Hugging Face configuration holding the ``TinyTransformer`` hyper-parameters.

    Registered under ``model_type = "tiny_transformer"``. Every extra keyword
    argument is handed to ``PretrainedConfig`` before the architecture fields
    are set.

    Args:
        vocab_size: size of the token embedding table.
        embed_dim: model width.
        num_heads: attention heads per encoder layer.
        ff_dim: feed-forward hidden size.
        num_layers: number of encoder layers.
        max_position_embeddings: stored on the config. The model wrapper in this
            file does not pass it to ``TinyTransformer``.
    """

    model_type = "tiny_transformer"

    def __init__(
        self,
        vocab_size=30522,
        embed_dim=64,
        num_heads=2,
        ff_dim=128,
        num_layers=4,
        max_position_embeddings=512,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Apply the architecture fields in a fixed order, after the base class
        # has consumed its own keyword arguments.
        architecture = {
            "vocab_size": vocab_size,
            "embed_dim": embed_dim,
            "num_heads": num_heads,
            "ff_dim": ff_dim,
            "num_layers": num_layers,
            "max_position_embeddings": max_position_embeddings,
        }
        for field_name, field_value in architecture.items():
            setattr(self, field_name, field_value)
class TinyTransformerForSequenceClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper so ``TinyTransformer`` loads via ``from_pretrained``.

    The wrapped encoder is stored as ``self.transformer``. ``forward`` returns a
    plain dict whose ``"logits"`` entry is the encoder output. That output has
    already been passed through a sigmoid, so despite the key name the values
    are probabilities.
    """

    config_class = TinyTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = 1
        # Build the encoder from the config's architecture fields.
        self.transformer = TinyTransformer(
            vocab_size=config.vocab_size,
            embed_dim=config.embed_dim,
            num_heads=config.num_heads,
            ff_dim=config.ff_dim,
            num_layers=config.num_layers,
        )

    def forward(self, input_ids, attention_mask=None):
        """Run the encoder on ``input_ids`` and return ``{"logits": probabilities}``.

        ``attention_mask`` is accepted so tokenizer output can be splatted in,
        but it is not used; padding tokens take part in the encoder's pooling.
        """
        probabilities = self.transformer(input_ids)
        return {"logits": probabilities}
# Load models and tokenizers
@st.cache_resource
def load_models_and_tokenizers(hf_token):
    """Load every model/tokenizer pair used by the app and place the models on CPU.

    Decorated with ``st.cache_resource`` because Streamlit re-executes the whole
    script on every widget interaction. Without the cache, each click on
    "Classify" re-loaded all five models, including T5-large, from scratch.

    Args:
        hf_token: Hugging Face access token, or ``None``. When set it is used to
            log in, and it is forwarded to every ``from_pretrained`` call.

    Returns:
        ``(models, tokenizers, device)``. ``models`` and ``tokenizers`` are dicts
        keyed by the same display names, and ``device`` is ``torch.device("cpu")``.
    """
    # huggingface_hub.login(token=None) falls back to an interactive prompt,
    # which nobody can answer in a deployed app, so log in only with a real token.
    if hf_token:
        login(token=hf_token)
    device = torch.device("cpu")  # forcing CPU as overhead of inference on GPU slows down the inference
    models = {}
    tokenizers = {}

    # NOTE(review): newer transformers releases deprecate `use_auth_token` in
    # favour of `token=`; switch once the pinned version is confirmed to support it.

    # Load Tiny-toxic-detector (custom architecture defined in this file).
    tiny_repo = "AssistantsLab/Tiny-Toxic-Detector"
    config = TinyTransformerConfig.from_pretrained(tiny_repo, use_auth_token=hf_token)
    models["Tiny-toxic-detector"] = TinyTransformerForSequenceClassification.from_pretrained(
        tiny_repo, config=config, use_auth_token=hf_token
    ).to(device)
    tokenizers["Tiny-toxic-detector"] = AutoTokenizer.from_pretrained(tiny_repo, use_auth_token=hf_token)

    # Load other models as (model repo, model class, tokenizer repo).
    # The ToxicChat T5 checkpoint is paired with the stock "t5-large" tokenizer.
    model_configs = [
        ("unitary/toxic-bert", AutoModelForSequenceClassification, "unitary/toxic-bert"),
        ("s-nlp/roberta_toxicity_classifier", AutoModelForSequenceClassification, "s-nlp/roberta_toxicity_classifier"),
        ("martin-ha/toxic-comment-model", AutoModelForSequenceClassification, "martin-ha/toxic-comment-model"),
        ("lmsys/toxicchat-t5-large-v1.0", AutoModelForSeq2SeqLM, "t5-large"),
    ]
    for model_name, model_class, tokenizer_name in model_configs:
        models[model_name] = model_class.from_pretrained(model_name, use_auth_token=hf_token).to(device)
        tokenizers[model_name] = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=hf_token)
    return models, tokenizers, device
# Prediction function
def _toxic_label_index(config):
    """Return the output index whose ``id2label`` name is "toxic" (case-insensitive), else 0."""
    id2label = getattr(config, "id2label", None) or {}
    for index, label in id2label.items():
        if str(label).lower() == "toxic":
            return int(index)
    return 0


def _classifier_says_toxic(logits, config):
    """Decide toxicity from the logits of a batch-of-one sequence-classification head.

    Some heads are not plain two-class heads: the config declares
    ``multi_label_classification``, or there are not exactly two outputs. Those
    heads score each label independently. For them, the logit of the label
    named "toxic" goes through a sigmoid and is thresholded at 0.5, because
    comparing two independent labels' logits against each other is meaningless.

    Two-class heads keep the original convention: toxic when logit 1 beats logit 0.
    """
    scores = logits[0]  # drop the batch dimension; a single text is classified
    independent_labels = (
        getattr(config, "problem_type", None) == "multi_label_classification"
        or scores.numel() != 2
    )
    if independent_labels:
        return bool(torch.sigmoid(scores[_toxic_label_index(config)]) > 0.5)
    return bool(scores[1] > scores[0])


def predict_toxicity(text, model, tokenizer, device, model_name):
    """Classify ``text`` with one model and time the call.

    Args:
        text: input string.
        model: loaded model matching ``model_name``.
        tokenizer: tokenizer paired with ``model``.
        device: torch device the encoded inputs are moved to.
        model_name: selects how the output is interpreted.
            - ``"lmsys/toxicchat-t5-large-v1.0"`` is generated from with a
              ``"ToxicChat: "`` prompt; the answer "positive" means toxic.
            - ``"Tiny-toxic-detector"`` yields a probability thresholded at 0.5.
            - Any other name is treated as a sequence-classification head
              (see ``_classifier_says_toxic``).

    Returns:
        ``(prediction, inference_time)``: "Toxic" or "Not Toxic", and the
        elapsed seconds measured with ``time.perf_counter``.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump.
    start_time = time.perf_counter()
    if model_name == "lmsys/toxicchat-t5-large-v1.0":
        prefix = "ToxicChat: "
        inputs = tokenizer.encode(prefix + text, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(inputs, max_new_tokens=5)
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower()
        is_toxic = answer == "positive"
    else:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=128,
            padding="max_length",
        ).to(device)
        # Not every forward() accepts token_type_ids; e.g. the Tiny detector's
        # forward only takes input_ids and attention_mask.
        inputs.pop("token_type_ids", None)
        with torch.no_grad():
            outputs = model(**inputs)
        if model_name == "Tiny-toxic-detector":
            # This model's "logits" are already sigmoid probabilities.
            is_toxic = bool(outputs["logits"].squeeze() > 0.5)
        else:
            is_toxic = _classifier_says_toxic(outputs.logits, getattr(model, "config", None))
    inference_time = time.perf_counter() - start_time
    prediction = "Toxic" if is_toxic else "Not Toxic"
    return prediction, inference_time
def main():
    """Render the multi-model toxicity UI and classify the entered text on demand.

    Page flow:
    1. Load every model and tokenizer, using the Hugging Face token from the
       ``AT`` environment variable.
    2. Show a text area.
    3. When "Classify" is pressed, run each model in turn behind a progress bar.
    4. Lay the predictions and per-model timings out in a three-column grid.

    Empty or whitespace-only input shows a warning instead of running the models.
    """
    st.set_page_config(page_title="Multi-Model Toxicity Detector", layout="wide")
    st.title("Multi-Model Toxicity Detector")

    # Load models
    hf_token = os.getenv('AT')
    models, tokenizers, device = load_models_and_tokenizers(hf_token)

    # sorted() is stable and False < True, so every other model keeps its load
    # order and "Tiny-toxic-detector" moves to the end.
    model_names = sorted(models.keys(), key=lambda name: name == "Tiny-toxic-detector")

    # User input
    text = st.text_area("Enter text to classify:", height=150)
    if not st.button("Classify"):
        return
    # Whitespace-only input has nothing to classify; treat it like empty input
    # instead of sending blank text to every model.
    if not text or not text.strip():
        st.warning("Please enter some text to classify.")
        return

    progress_bar = st.progress(0)
    results = []
    for i, model_name in enumerate(model_names):
        with st.spinner(f"Classifying with {model_name}..."):
            prediction, inference_time = predict_toxicity(
                text, models[model_name], tokenizers[model_name], device, model_name
            )
        results.append((model_name, prediction, inference_time))
        progress_bar.progress((i + 1) / len(model_names))
    st.success("Classification complete!")
    progress_bar.empty()

    # Display results in a grid, filling the three columns left to right.
    columns = st.columns(3)
    for i, (model_name, prediction, inference_time) in enumerate(results):
        with columns[i % 3]:
            st.subheader(model_name)
            st.write(f"Prediction: {prediction}")
            st.write(f"Inference Time: {inference_time:.4f}s")
            st.write("---")
# Entry point: build the app only when this file is executed as the main script.
if __name__ == "__main__":
    main()