Michielo committed (verified)
Commit ddaa9e6 · Parent(s): f754e10

Create app.py

Files changed (1):
  app.py (+168, -0)
app.py ADDED
@@ -0,0 +1,168 @@
+ import streamlit as st
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
+ from huggingface_hub import login
+ import os
+ import time
+
+ # Model architecture
+ class TinyTransformer(nn.Module):
+     def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers):
+         super().__init__()
+         self.embedding = nn.Embedding(vocab_size, embed_dim)
+         self.pos_encoding = nn.Parameter(torch.zeros(1, 512, embed_dim))
+         encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim, batch_first=True)
+         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+         self.fc = nn.Linear(embed_dim, 1)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         x = self.embedding(x) + self.pos_encoding[:, :x.size(1), :]
+         x = self.transformer(x)
+         x = x.mean(dim=1)  # global average pooling over the sequence
+         x = self.fc(x)
+         return self.sigmoid(x)
+
+ class TinyTransformerConfig(PretrainedConfig):
+     model_type = "tiny_transformer"
+
+     def __init__(
+         self,
+         vocab_size=30522,
+         embed_dim=64,
+         num_heads=2,
+         ff_dim=128,
+         num_layers=4,
+         max_position_embeddings=512,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.ff_dim = ff_dim
+         self.num_layers = num_layers
+         self.max_position_embeddings = max_position_embeddings
+
+ class TinyTransformerForSequenceClassification(PreTrainedModel):
+     config_class = TinyTransformerConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = 1
+         self.transformer = TinyTransformer(
+             config.vocab_size,
+             config.embed_dim,
+             config.num_heads,
+             config.ff_dim,
+             config.num_layers
+         )
+
+     def forward(self, input_ids, attention_mask=None):
+         outputs = self.transformer(input_ids)
+         return {"logits": outputs}
+
+ # Load models and tokenizers once per session
+ @st.cache_resource
+ def load_models_and_tokenizers(hf_token):
+     login(token=hf_token)
+     device = torch.device("cpu")  # force CPU: for models this small, GPU transfer overhead outweighs any speedup
+
+     models = {}
+     tokenizers = {}
+
+     # Load Tiny-toxic-detector (note: use_auth_token is deprecated in newer transformers releases in favor of token=)
+     config = TinyTransformerConfig.from_pretrained("AssistantsLab/Tiny-Toxic-Detector", use_auth_token=hf_token)
+     models["Tiny-toxic-detector"] = TinyTransformerForSequenceClassification.from_pretrained("AssistantsLab/Tiny-Toxic-Detector", config=config, use_auth_token=hf_token).to(device)
+     tokenizers["Tiny-toxic-detector"] = AutoTokenizer.from_pretrained("AssistantsLab/Tiny-Toxic-Detector", use_auth_token=hf_token)
+
+     # Load the comparison models; each entry is (model repo, model class, tokenizer repo)
+     model_configs = [
+         ("unitary/toxic-bert", AutoModelForSequenceClassification, "unitary/toxic-bert"),
+         ("s-nlp/roberta_toxicity_classifier", AutoModelForSequenceClassification, "s-nlp/roberta_toxicity_classifier"),
+         ("martin-ha/toxic-comment-model", AutoModelForSequenceClassification, "martin-ha/toxic-comment-model"),
+         ("lmsys/toxicchat-t5-large-v1.0", AutoModelForSeq2SeqLM, "t5-large")
+     ]
+
+     for model_name, model_class, tokenizer_name in model_configs:
+         models[model_name] = model_class.from_pretrained(model_name, use_auth_token=hf_token).to(device)
+         tokenizers[model_name] = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=hf_token)
+
+     return models, tokenizers, device
+
+ # Prediction function
+ def predict_toxicity(text, model, tokenizer, device, model_name):
+     start_time = time.time()
+
+     if model_name == "lmsys/toxicchat-t5-large-v1.0":
+         # Seq2seq classifier: it is expected to generate "positive" for toxic input
+         prefix = "ToxicChat: "
+         inputs = tokenizer.encode(prefix + text, return_tensors="pt").to(device)
+
+         with torch.no_grad():
+             outputs = model.generate(inputs, max_new_tokens=5)
+
+         prediction = tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower()
+         prediction = "Toxic" if prediction == "positive" else "Not Toxic"
+     else:
+         inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128, padding="max_length").to(device)
+
+         # Some tokenizers emit token_type_ids, which these models do not accept
+         if "token_type_ids" in inputs:
+             del inputs["token_type_ids"]
+
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         if model_name == "Tiny-toxic-detector":
+             # The forward pass already applies a sigmoid, so 0.5 is a probability threshold
+             logits = outputs["logits"].squeeze()
+             prediction = "Toxic" if logits > 0.5 else "Not Toxic"
+         else:
+             # Assumes label index 1 is the toxic class
+             logits = outputs.logits.squeeze()
+             prediction = "Toxic" if logits[1] > logits[0] else "Not Toxic"
+
+     end_time = time.time()
+     inference_time = end_time - start_time
+
+     return prediction, inference_time
+
+ def main():
+     st.set_page_config(page_title="Multi-Model Toxicity Detector", layout="wide")
+     st.title("Multi-Model Toxicity Detector")
+
+     # Load models
+     hf_token = os.getenv('AT')
+     models, tokenizers, device = load_models_and_tokenizers(hf_token)
+
+     # Reorder so "Tiny-toxic-detector" runs last (sorted() places False before True)
+     model_names = sorted(models.keys(), key=lambda x: x == "Tiny-toxic-detector")
+
+     # User input
+     text = st.text_area("Enter text to classify:", height=150)
+
+     if st.button("Classify"):
+         if text:
+             progress_bar = st.progress(0)
+             results = []
+
+             for i, model_name in enumerate(model_names):
+                 with st.spinner(f"Classifying with {model_name}..."):
+                     prediction, inference_time = predict_toxicity(text, models[model_name], tokenizers[model_name], device, model_name)
+                     results.append((model_name, prediction, inference_time))
+                     progress_bar.progress((i + 1) / len(model_names))
+
+             st.success("Classification complete!")
+             progress_bar.empty()
+
+             # Display results in a three-column grid
+             col1, col2, col3 = st.columns(3)
+             for i, (model_name, prediction, inference_time) in enumerate(results):
+                 with [col1, col2, col3][i % 3]:
+                     st.subheader(model_name)
+                     st.write(f"Prediction: {prediction}")
+                     st.write(f"Inference Time: {inference_time:.4f}s")
+                     st.write("---")
+         else:
+             st.warning("Please enter some text to classify.")
+
+ if __name__ == "__main__":
+     main()
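
To try the app, one would typically export the AT environment variable (the token name app.py reads) and launch it with `streamlit run app.py`. The snippet below is a minimal sketch of calling the two helpers directly, outside the UI; it assumes app.py is importable from the working directory, and that invoking a st.cache_resource function without a running Streamlit session falls back to a plain in-process cache (recent Streamlit versions do this, with a warning):

import os
from app import load_models_and_tokenizers, predict_toxicity

# Assumes the 'AT' env var holds a valid Hugging Face token, as app.py itself expects
models, tokenizers, device = load_models_and_tokenizers(os.getenv("AT"))

name = "Tiny-toxic-detector"
prediction, seconds = predict_toxicity("have a nice day", models[name], tokenizers[name], device, name)
print(f"{name}: {prediction} ({seconds:.4f}s)")  # e.g. "Not Toxic"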