ifmain committed on
Commit
f9d050d
·
verified ·
1 Parent(s): 673e117

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -16
app.py CHANGED
@@ -1,25 +1,60 @@
1
- import gradio as gr
2
  import torch
3
- from moderation import *
 
 
 
 
 
 
4
 
 
 
5
 
6
- moderation = ModerationModel()
7
- moderation.load_state_dict(torch.load('moderation_model.pth', map_location=torch.device('cpu'))) # Remove map_location if run on gpu
8
- moderation.eval()
 
 
 
 
9
 
10
  def predict_moderation(text):
11
- embeddings_for_prediction = getEmb(text)
12
- prediction = predict(moderation, embeddings_for_prediction)
13
- category_scores = prediction.get('category_scores', {})
14
- detected = prediction.get('detected', False)
15
- return category_scores, str(detected)
 
 
 
 
 
 
 
 
16
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- iface = gr.Interface(fn=predict_moderation,
19
- inputs="text",
20
- outputs=[gr.Label(label="Category Scores"), gr.Label(label="Detected")],
21
- title="Moderation Model",
22
- description="Enter text to check for moderation flags.")
23
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- iface.launch()
 
1
+ import json
2
  import torch
3
+ from transformers import BertTokenizer, BertForSequenceClassification
4
+ import gradio as gr
5
+
model_name = "ifmain/ModerationBERT-En-02"

# Output labels. predict_moderation pairs the i-th name with the i-th model
# output, so the order of this list matters.
# NOTE(review): several names appear in both underscore and slash/hyphen
# spellings (e.g. 'self_harm' and 'self-harm'); presumably this mirrors the
# checkpoint's label set -- confirm against the model card.
categories = [
    'harassment', 'harassment_threatening', 'hate', 'hate_threatening',
    'self_harm', 'self_harm_instructions', 'self_harm_intent', 'sexual',
    'sexual_minors', 'violence', 'violence_graphic', 'self-harm',
    'sexual/minors', 'hate/threatening', 'violence/graphic',
    'self-harm/intent', 'self-harm/instructions', 'harassment/threatening'
]

tokenizer = BertTokenizer.from_pretrained(model_name)
# Size the classification head from the label list (18 entries) so the two
# cannot drift apart.
model = BertForSequenceClassification.from_pretrained(
    model_name, num_labels=len(categories)
)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
21
 
def predict_moderation(text):
    """Score ``text`` against every moderation category.

    Args:
        text: The string to classify. ``None`` is treated as an empty string.

    Returns:
        A ``(category_scores, detected)`` tuple. ``category_scores`` maps each
        name in ``categories`` to the sigmoid of the matching model logit, as
        a Python float. ``detected`` is the string ``"True"`` when any of the
        probabilities is above 0.5, otherwise ``"False"``.
    """
    if text is None:
        text = ""

    # Calling the tokenizer directly replaces the deprecated ``encode_plus``;
    # the keyword arguments are unchanged.
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=128,
        return_token_type_ids=False,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    # Only one text was encoded, so row 0 holds all of its logits.
    probs = torch.sigmoid(outputs.logits)[0].cpu().tolist()
    category_scores = dict(zip(categories, probs))

    detected = any(prob > 0.5 for prob in probs)

    return category_scores, str(detected)
47
 
 
 
 
 
 
48
 
# Gradio UI: one text box in; per-category scores and an overall flag out.
iface = gr.Interface(
    fn=predict_moderation,
    inputs=gr.Textbox(label="Enter text"),
    outputs=[
        gr.Label(label="Ratings by category"),
        gr.Label(label="Was a violation detected?"),
    ],
    title="Text moderation",
    description="Enter text to check it for content violations (ModerationBERT-En-02 model).",
)

# Launch only when executed as a script, so that importing this module does
# not start the web server as a side effect.
if __name__ == "__main__":
    iface.launch()