Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,11 +3,13 @@ import torch
|
|
3 |
from transformers import BertTokenizerFast, BertForTokenClassification
|
4 |
import gradio as gr
|
5 |
|
|
|
6 |
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
7 |
model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
|
8 |
model.eval()
|
9 |
model.to('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
|
|
|
11 |
id2label = {
|
12 |
0: 'O',
|
13 |
1: 'B-STEREO',
|
@@ -20,23 +22,26 @@ id2label = {
|
|
20 |
|
21 |
label2id = {v: k for k, v in id2label.items()}
|
22 |
|
|
|
23 |
label_colors = {
|
24 |
-
"STEREO": "rgba(255, 0, 0, 0.2)",
|
25 |
-
"GEN": "rgba(0, 0, 255, 0.2)",
|
26 |
-
"UNFAIR": "rgba(0, 255, 0, 0.2)"
|
27 |
}
|
28 |
|
|
|
29 |
def post_process_entities(result):
|
30 |
prev_entity_type = None
|
31 |
-
|
32 |
-
for i, token_data in enumerate(result):
|
33 |
labels = token_data["labels"]
|
34 |
-
|
35 |
labels = list(set(labels))
|
|
|
|
|
36 |
for entity_type in ["GEN", "UNFAIR", "STEREO"]:
|
37 |
if f"B-{entity_type}" in labels and f"I-{entity_type}" in labels:
|
38 |
labels.remove(f"I-{entity_type}")
|
39 |
|
|
|
40 |
current_entity_type = None
|
41 |
current_label = None
|
42 |
for label in labels:
|
@@ -48,19 +53,18 @@ def post_process_entities(result):
|
|
48 |
if current_label.startswith("B-") and prev_entity_type == current_entity_type:
|
49 |
labels.remove(current_label)
|
50 |
labels.append(f"I-{current_entity_type}")
|
51 |
-
|
52 |
if current_label.startswith("I-") and prev_entity_type != current_entity_type:
|
53 |
labels.remove(current_label)
|
54 |
labels.append(f"B-{current_entity_type}")
|
55 |
|
56 |
prev_entity_type = current_entity_type
|
57 |
else:
|
58 |
-
prev_entity_type = None
|
59 |
|
60 |
token_data["labels"] = labels
|
61 |
return result
|
62 |
|
63 |
-
|
64 |
def generate_json(sentence):
|
65 |
inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
66 |
input_ids = inputs['input_ids'].to(model.device)
|
@@ -84,12 +88,13 @@ def generate_json(sentence):
|
|
84 |
|
85 |
return json.dumps(result, indent=4)
|
86 |
|
|
|
87 |
def predict_ner_tags_with_json(sentence):
|
88 |
json_result = generate_json(sentence)
|
89 |
|
90 |
result = json.loads(json_result)
|
91 |
|
92 |
-
word_row = []
|
93 |
stereo_row = []
|
94 |
gen_row = []
|
95 |
unfair_row = []
|
@@ -141,16 +146,30 @@ def predict_ner_tags_with_json(sentence):
|
|
141 |
|
142 |
return f"{matrix_html}<br><pre>{json_result}</pre>"
|
143 |
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
)
|
154 |
-
|
155 |
-
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from transformers import BertTokenizerFast, BertForTokenClassification
|
4 |
import gradio as gr
|
5 |
|
6 |
+
# Initialize tokenizer and model
|
7 |
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
8 |
model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
|
9 |
model.eval()
|
10 |
model.to('cuda' if torch.cuda.is_available() else 'cpu')
|
11 |
|
12 |
+
# Mapping IDs to labels
|
13 |
id2label = {
|
14 |
0: 'O',
|
15 |
1: 'B-STEREO',
|
|
|
22 |
|
23 |
label2id = {v: k for k, v in id2label.items()}
|
24 |
|
25 |
+
# Entity colors for highlights
|
26 |
label_colors = {
|
27 |
+
"STEREO": "rgba(255, 0, 0, 0.2)", # Light Red
|
28 |
+
"GEN": "rgba(0, 0, 255, 0.2)", # Light Blue
|
29 |
+
"UNFAIR": "rgba(0, 255, 0, 0.2)" # Light Green
|
30 |
}
|
31 |
|
32 |
+
# Post-process entity tags
|
33 |
def post_process_entities(result):
|
34 |
prev_entity_type = None
|
35 |
+
for token_data in result:
|
|
|
36 |
labels = token_data["labels"]
|
|
|
37 |
labels = list(set(labels))
|
38 |
+
|
39 |
+
# Handle conflicting B- and I- tags for the same entity
|
40 |
for entity_type in ["GEN", "UNFAIR", "STEREO"]:
|
41 |
if f"B-{entity_type}" in labels and f"I-{entity_type}" in labels:
|
42 |
labels.remove(f"I-{entity_type}")
|
43 |
|
44 |
+
# Handle sequence rules
|
45 |
current_entity_type = None
|
46 |
current_label = None
|
47 |
for label in labels:
|
|
|
53 |
if current_label.startswith("B-") and prev_entity_type == current_entity_type:
|
54 |
labels.remove(current_label)
|
55 |
labels.append(f"I-{current_entity_type}")
|
|
|
56 |
if current_label.startswith("I-") and prev_entity_type != current_entity_type:
|
57 |
labels.remove(current_label)
|
58 |
labels.append(f"B-{current_entity_type}")
|
59 |
|
60 |
prev_entity_type = current_entity_type
|
61 |
else:
|
62 |
+
prev_entity_type = None
|
63 |
|
64 |
token_data["labels"] = labels
|
65 |
return result
|
66 |
|
67 |
+
# Generate JSON results
|
68 |
def generate_json(sentence):
|
69 |
inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
70 |
input_ids = inputs['input_ids'].to(model.device)
|
|
|
88 |
|
89 |
return json.dumps(result, indent=4)
|
90 |
|
91 |
+
# Predict function
|
92 |
def predict_ner_tags_with_json(sentence):
|
93 |
json_result = generate_json(sentence)
|
94 |
|
95 |
result = json.loads(json_result)
|
96 |
|
97 |
+
word_row = []
|
98 |
stereo_row = []
|
99 |
gen_row = []
|
100 |
unfair_row = []
|
|
|
146 |
|
147 |
return f"{matrix_html}<br><pre>{json_result}</pre>"
|
148 |
|
149 |
+
# Gradio Interface
|
150 |
+
iface = gr.Blocks()
|
151 |
+
|
152 |
+
with iface:
|
153 |
+
with gr.Row():
|
154 |
+
gr.Markdown(
|
155 |
+
"""
|
156 |
+
# Social Bias Named Entity Recognition (with BERT) 🕵
|
157 |
+
Enter a sentence to predict biased parts of speech tags. This model uses multi-label `BertForTokenClassification` to label the entities:
|
158 |
+
- **Generalizations (GEN)**
|
159 |
+
- **Unfairness (UNFAIR)**
|
160 |
+
- **Stereotypes (STEREO)**
|
161 |
+
|
162 |
+
Labels follow the BIO format. Try it out!
|
163 |
+
|
164 |
+
- **[Blog Post](https://huggingface.co/blog/maximuspowers/bias-entity-recognition)**
|
165 |
+
- **[Model Page](https://huggingface.co/maximuspowers/bias-detection-ner)**
|
166 |
+
"""
|
167 |
+
)
|
168 |
+
with gr.Row():
|
169 |
+
input_box = gr.Textbox(label="Input Sentence")
|
170 |
+
with gr.Row():
|
171 |
+
output_box = gr.HTML(label="Entity Matrix and JSON Output")
|
172 |
+
|
173 |
+
input_box.change(predict_ner_tags_with_json, inputs=[input_box], outputs=[output_box])
|
174 |
+
|
175 |
+
iface.launch(share=True)
|