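# Streamlit demo app: runs the same input text through five toxicity classifiers
# (the custom Tiny-Toxic-Detector plus four off-the-shelf Hugging Face models) and
# compares their predictions and inference times side by side.
# Launch with: streamlit run app.py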
import streamlit as st
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
from huggingface_hub import login
import os
import time
# Model Architecture
class TinyTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoding = nn.Parameter(torch.zeros(1, 512, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.embedding(x) + self.pos_encoding[:, :x.size(1), :]
        x = self.transformer(x)
        x = x.mean(dim=1)  # Global average pooling
        x = self.fc(x)
        return self.sigmoid(x)
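# Shape flow through TinyTransformer.forward, as implied by the layers above:
#   input_ids                    -> (batch, seq_len)
#   embedding + pos_encoding     -> (batch, seq_len, embed_dim)
#   transformer encoder          -> (batch, seq_len, embed_dim)
#   mean over seq_len            -> (batch, embed_dim)
#   fc + sigmoid                 -> (batch, 1) toxicity probability in [0, 1]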
class TinyTransformerConfig(PretrainedConfig):
    model_type = "tiny_transformer"

    def __init__(
        self,
        vocab_size=30522,
        embed_dim=64,
        num_heads=2,
        ff_dim=128,
        num_layers=4,
        max_position_embeddings=512,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.num_layers = num_layers
        self.max_position_embeddings = max_position_embeddings
class TinyTransformerForSequenceClassification(PreTrainedModel):
    config_class = TinyTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = 1
        self.transformer = TinyTransformer(
            config.vocab_size,
            config.embed_dim,
            config.num_heads,
            config.ff_dim,
            config.num_layers
        )

    def forward(self, input_ids, attention_mask=None):
        outputs = self.transformer(input_ids)
        return {"logits": outputs}
# Load models and tokenizers
@st.cache_resource
def load_models_and_tokenizers(hf_token):
    login(token=hf_token)
    device = torch.device("cpu")  # force CPU: for models this small, GPU transfer overhead outweighs any speedup
    models = {}
    tokenizers = {}

    # Load Tiny-toxic-detector
    config = TinyTransformerConfig.from_pretrained("AssistantsLab/Tiny-Toxic-Detector", use_auth_token=hf_token)
    models["Tiny-toxic-detector"] = TinyTransformerForSequenceClassification.from_pretrained("AssistantsLab/Tiny-Toxic-Detector", config=config, use_auth_token=hf_token).to(device)
    tokenizers["Tiny-toxic-detector"] = AutoTokenizer.from_pretrained("AssistantsLab/Tiny-Toxic-Detector", use_auth_token=hf_token)

    # Load other models
    model_configs = [
        ("unitary/toxic-bert", AutoModelForSequenceClassification, "unitary/toxic-bert"),
        ("s-nlp/roberta_toxicity_classifier", AutoModelForSequenceClassification, "s-nlp/roberta_toxicity_classifier"),
        ("martin-ha/toxic-comment-model", AutoModelForSequenceClassification, "martin-ha/toxic-comment-model"),
        ("lmsys/toxicchat-t5-large-v1.0", AutoModelForSeq2SeqLM, "t5-large")
    ]

    for model_name, model_class, tokenizer_name in model_configs:
        models[model_name] = model_class.from_pretrained(model_name, use_auth_token=hf_token).to(device)
        tokenizers[model_name] = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=hf_token)

    return models, tokenizers, device
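# st.cache_resource keeps the loaded models and tokenizers in memory across Streamlit
# reruns, so downloading and loading the weights happens once per process rather than
# on every button click.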
# Prediction function
def predict_toxicity(text, model, tokenizer, device, model_name):
    start_time = time.time()
    if model_name == "lmsys/toxicchat-t5-large-v1.0":
        # Seq2seq model: it generates a text label ("positive" = toxic) instead of class logits
        prefix = "ToxicChat: "
        inputs = tokenizer.encode(prefix + text, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(inputs, max_new_tokens=5)
        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower()
        prediction = "Toxic" if prediction == "positive" else "Not Toxic"
    else:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128, padding="max_length").to(device)
        if "token_type_ids" in inputs:
            del inputs["token_type_ids"]
        with torch.no_grad():
            outputs = model(**inputs)
        if model_name == "Tiny-toxic-detector":
            # The custom model already applies a sigmoid, so this value is a probability
            logits = outputs["logits"].squeeze()
            prediction = "Toxic" if logits > 0.5 else "Not Toxic"
        else:
            # Classification heads: index 1 is treated as the toxic class, index 0 as non-toxic
            logits = outputs.logits.squeeze()
            prediction = "Toxic" if logits[1] > logits[0] else "Not Toxic"
    end_time = time.time()
    inference_time = end_time - start_time
    return prediction, inference_time
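# Illustrative call (input text is hypothetical), showing the return contract:
#   label, seconds = predict_toxicity("some user text", models["unitary/toxic-bert"],
#                                     tokenizers["unitary/toxic-bert"], device, "unitary/toxic-bert")
#   label is "Toxic" or "Not Toxic"; seconds is the wall-clock inference time.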
def main():
    st.set_page_config(page_title="Multi-Model Toxicity Detector", layout="wide")
    st.title("Multi-Model Toxicity Detector")

    # Load models
    hf_token = os.getenv('AT')
    models, tokenizers, device = load_models_and_tokenizers(hf_token)

    # Order the model names so that "Tiny-toxic-detector" is classified last
    model_names = sorted(models.keys(), key=lambda x: x == "Tiny-toxic-detector")

    # User input
    text = st.text_area("Enter text to classify:", height=150)

    if st.button("Classify"):
        if text:
            progress_bar = st.progress(0)
            results = []
            for i, model_name in enumerate(model_names):
                with st.spinner(f"Classifying with {model_name}..."):
                    prediction, inference_time = predict_toxicity(text, models[model_name], tokenizers[model_name], device, model_name)
                    results.append((model_name, prediction, inference_time))
                progress_bar.progress((i + 1) / len(model_names))
            st.success("Classification complete!")
            progress_bar.empty()

            # Display results in a grid
            col1, col2, col3 = st.columns(3)
            for i, (model_name, prediction, inference_time) in enumerate(results):
                with [col1, col2, col3][i % 3]:
                    st.subheader(model_name)
                    st.write(f"Prediction: {prediction}")
                    st.write(f"Inference Time: {inference_time:.4f}s")
                    st.write("---")
        else:
            st.warning("Please enter some text to classify.")

if __name__ == "__main__":
    main()