import os
import time

import streamlit as st
import torch
import torch.nn as nn
from huggingface_hub import login
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
)

# Display name / hub id of the custom model, and the one generative (T5) model.
TINY_MODEL_NAME = "Tiny-toxic-detector"
TINY_REPO_ID = "AssistantsLab/Tiny-Toxic-Detector"
TOXICCHAT_T5_MODEL_NAME = "lmsys/toxicchat-t5-large-v1.0"


# Model Architecture
class TinyTransformer(nn.Module):
    """Small Transformer encoder mapping token ids to one toxicity probability.

    Attribute names (embedding, pos_encoding, transformer, fc) must not change:
    they determine the parameter keys ``from_pretrained`` matches against the
    published checkpoint.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Trainable (zero-initialised) positional table covering up to 512
        # positions; longer inputs would fail in forward(), so callers truncate.
        self.pos_encoding = nn.Parameter(torch.zeros(1, 512, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map ``(batch, seq)`` token ids to a ``(batch, 1)`` tensor of probabilities."""
        x = self.embedding(x) + self.pos_encoding[:, : x.size(1), :]
        x = self.transformer(x)
        # Global average pooling over every position; no attention mask is
        # applied, so padding tokens are included in the mean.
        x = x.mean(dim=1)
        x = self.fc(x)
        return self.sigmoid(x)


class TinyTransformerConfig(PretrainedConfig):
    """Hugging Face config holding the TinyTransformer hyper-parameters."""

    model_type = "tiny_transformer"

    def __init__(
        self,
        vocab_size=30522,
        embed_dim=64,
        num_heads=2,
        ff_dim=128,
        num_layers=4,
        max_position_embeddings=512,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.num_layers = num_layers
        # Stored for completeness; TinyTransformer hard-codes 512 positions.
        self.max_position_embeddings = max_position_embeddings


class TinyTransformerForSequenceClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper so TinyTransformer loads via ``from_pretrained``."""

    config_class = TinyTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = 1
        self.transformer = TinyTransformer(
            config.vocab_size,
            config.embed_dim,
            config.num_heads,
            config.ff_dim,
            config.num_layers,
        )

    def forward(self, input_ids, attention_mask=None):
        """Return ``{"logits": probabilities}``.

        ``attention_mask`` is accepted so tokenizer output can be splatted in,
        but it is ignored by the underlying model. Despite the key name, the
        values are already sigmoid probabilities, not raw logits.
        """
        outputs = self.transformer(input_ids)
        return {"logits": outputs}


# Load models and tokenizers
@st.cache_resource
def load_models_and_tokenizers(hf_token):
    """Load every model and tokenizer once and cache them across Streamlit reruns.

    Args:
        hf_token: Hugging Face access token, or ``None``/empty to load anonymously.

    Returns:
        ``(models, tokenizers, device)``; both dicts are keyed by the model's
        display name (the hub repo id, or ``"Tiny-toxic-detector"``).
    """
    token = hf_token or None
    # login(token=None) falls back to an interactive prompt, which cannot be
    # answered from a Streamlit server, so only log in when a token exists.
    if token:
        login(token=token)
    # Forcing CPU: per the original author, GPU overhead slows down inference here.
    device = torch.device("cpu")
    models = {}
    tokenizers = {}

    # Load Tiny-toxic-detector (custom architecture defined above).
    config = TinyTransformerConfig.from_pretrained(TINY_REPO_ID, token=token)
    models[TINY_MODEL_NAME] = TinyTransformerForSequenceClassification.from_pretrained(
        TINY_REPO_ID, config=config, token=token
    ).to(device)
    tokenizers[TINY_MODEL_NAME] = AutoTokenizer.from_pretrained(TINY_REPO_ID, token=token)

    # Load other models: (model repo id, model class, tokenizer repo id).
    model_configs = [
        ("unitary/toxic-bert", AutoModelForSequenceClassification, "unitary/toxic-bert"),
        (
            "s-nlp/roberta_toxicity_classifier",
            AutoModelForSequenceClassification,
            "s-nlp/roberta_toxicity_classifier",
        ),
        (
            "martin-ha/toxic-comment-model",
            AutoModelForSequenceClassification,
            "martin-ha/toxic-comment-model",
        ),
        (TOXICCHAT_T5_MODEL_NAME, AutoModelForSeq2SeqLM, "t5-large"),
    ]
    for model_name, model_class, tokenizer_name in model_configs:
        models[model_name] = model_class.from_pretrained(model_name, token=token).to(device)
        tokenizers[model_name] = AutoTokenizer.from_pretrained(tokenizer_name, token=token)

    return models, tokenizers, device


# Prediction function
def _is_toxic(logits, config):
    """Decide toxicity from a single example's classifier logits.

    A two-class single-label head is read as (non-toxic, toxic) and the two
    scores are compared directly. Any other head (multi-label, or not exactly
    two outputs) scores labels independently, so the column labelled "toxic"
    (falling back to column 0) is sigmoid-thresholded at 0.5. This matters for
    multi-label models such as unitary/toxic-bert, whose label 0 is "toxic"
    and label 1 "severe_toxic" per its hub config (verify if models change).
    """
    multi_label = (
        getattr(config, "problem_type", None) == "multi_label_classification"
        or logits.numel() != 2
    )
    if not multi_label:
        return bool(logits[1] > logits[0])
    id2label = getattr(config, "id2label", None) or {}
    toxic_index = next(
        (int(index) for index, label in id2label.items() if str(label).lower() == "toxic"),
        0,
    )
    return bool(torch.sigmoid(logits[toxic_index]) > 0.5)


def _predict_with_toxicchat_t5(text, model, tokenizer, device):
    """Classify with the ToxicChat T5 generator; a "positive" answer means toxic."""
    prefix = "ToxicChat: "
    inputs = tokenizer.encode(prefix + text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(inputs, max_new_tokens=5)
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower()
    return "Toxic" if answer == "positive" else "Not Toxic"


def _predict_with_classifier(text, model, tokenizer, device, model_name):
    """Classify with a sequence-classification model (hub model or Tiny detector)."""
    # Every input is truncated/padded to exactly 128 tokens; longer text is cut.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        padding="max_length",
    ).to(device)
    # The Tiny detector's forward() does not accept token_type_ids.
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]
    with torch.no_grad():
        outputs = model(**inputs)

    if model_name == TINY_MODEL_NAME:
        # The Tiny detector already applies a sigmoid: this is a probability.
        toxic_probability = outputs["logits"].item()
        return "Toxic" if toxic_probability > 0.5 else "Not Toxic"

    return "Toxic" if _is_toxic(outputs.logits[0], model.config) else "Not Toxic"


def predict_toxicity(text, model, tokenizer, device, model_name):
    """Classify *text* with one model and measure how long it took.

    Returns:
        ``(prediction, inference_time)`` where prediction is ``"Toxic"`` or
        ``"Not Toxic"`` and inference_time is elapsed seconds, tokenization
        and decoding included.
    """
    start_time = time.perf_counter()
    if model_name == TOXICCHAT_T5_MODEL_NAME:
        prediction = _predict_with_toxicchat_t5(text, model, tokenizer, device)
    else:
        prediction = _predict_with_classifier(text, model, tokenizer, device, model_name)
    inference_time = time.perf_counter() - start_time
    return prediction, inference_time


def main():
    """Render the UI: read text, classify it with every model, and show a results grid."""
    # Keep this as the first Streamlit call (older Streamlit releases require it).
    st.set_page_config(page_title="Multi-Model Toxicity Detector", layout="wide")
    st.title("Multi-Model Toxicity Detector")

    # Load models (token comes from the "AT" environment variable; may be unset).
    hf_token = os.getenv('AT')
    models, tokenizers, device = load_models_and_tokenizers(hf_token)

    # Order names so "Tiny-toxic-detector" is last; sorted() is stable, so the
    # other models keep their loading order.
    model_names = sorted(models.keys(), key=lambda name: name == TINY_MODEL_NAME)

    # User input
    text = st.text_area("Enter text to classify:", height=150)
    if not st.button("Classify"):
        return
    if not text.strip():
        st.warning("Please enter some text to classify.")
        return

    progress_bar = st.progress(0)
    results = []
    for i, model_name in enumerate(model_names):
        with st.spinner(f"Classifying with {model_name}..."):
            prediction, inference_time = predict_toxicity(
                text, models[model_name], tokenizers[model_name], device, model_name
            )
            results.append((model_name, prediction, inference_time))
        progress_bar.progress((i + 1) / len(model_names))

    st.success("Classification complete!")
    progress_bar.empty()

    # Display results in a three-column grid, filling columns round-robin.
    columns = st.columns(3)
    for i, (model_name, prediction, inference_time) in enumerate(results):
        with columns[i % 3]:
            st.subheader(model_name)
            st.write(f"Prediction: {prediction}")
            st.write(f"Inference Time: {inference_time:.4f}s")
            st.write("---")


if __name__ == "__main__":
    main()