maximuspowers committed on
Commit
34ab835
·
verified ·
1 Parent(s): 2c53668

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -23
app.py CHANGED
@@ -3,11 +3,13 @@ import torch
3
  from transformers import BertTokenizerFast, BertForTokenClassification
4
  import gradio as gr
5
 
 
6
  tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
7
  model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
8
  model.eval()
9
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
10
 
 
11
  id2label = {
12
  0: 'O',
13
  1: 'B-STEREO',
@@ -20,23 +22,26 @@ id2label = {
20
 
21
  label2id = {v: k for k, v in id2label.items()}
22
 
 
23
  label_colors = {
24
- "STEREO": "rgba(255, 0, 0, 0.2)",
25
- "GEN": "rgba(0, 0, 255, 0.2)",
26
- "UNFAIR": "rgba(0, 255, 0, 0.2)"
27
  }
28
 
 
29
  def post_process_entities(result):
30
  prev_entity_type = None
31
-
32
- for i, token_data in enumerate(result):
33
  labels = token_data["labels"]
34
-
35
  labels = list(set(labels))
 
 
36
  for entity_type in ["GEN", "UNFAIR", "STEREO"]:
37
  if f"B-{entity_type}" in labels and f"I-{entity_type}" in labels:
38
  labels.remove(f"I-{entity_type}")
39
 
 
40
  current_entity_type = None
41
  current_label = None
42
  for label in labels:
@@ -48,19 +53,18 @@ def post_process_entities(result):
48
  if current_label.startswith("B-") and prev_entity_type == current_entity_type:
49
  labels.remove(current_label)
50
  labels.append(f"I-{current_entity_type}")
51
-
52
  if current_label.startswith("I-") and prev_entity_type != current_entity_type:
53
  labels.remove(current_label)
54
  labels.append(f"B-{current_entity_type}")
55
 
56
  prev_entity_type = current_entity_type
57
  else:
58
- prev_entity_type = None
59
 
60
  token_data["labels"] = labels
61
  return result
62
 
63
-
64
  def generate_json(sentence):
65
  inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
66
  input_ids = inputs['input_ids'].to(model.device)
@@ -84,12 +88,13 @@ def generate_json(sentence):
84
 
85
  return json.dumps(result, indent=4)
86
 
 
87
  def predict_ner_tags_with_json(sentence):
88
  json_result = generate_json(sentence)
89
 
90
  result = json.loads(json_result)
91
 
92
- word_row = []
93
  stereo_row = []
94
  gen_row = []
95
  unfair_row = []
@@ -141,16 +146,30 @@ def predict_ner_tags_with_json(sentence):
141
 
142
  return f"{matrix_html}<br><pre>{json_result}</pre>"
143
 
144
- iface = gr.Interface(
145
- fn=predict_ner_tags_with_json,
146
- inputs=[gr.Textbox(label="Input Sentence")],
147
- outputs=[gr.HTML(label="Entity Matrix and JSON Output")],
148
- title="Social Bias Named Entity Recognition (with BERT) 🕵",
149
- description=("Enter a sentence to predict biased parts of speech tags. This model uses multi-label BertForTokenClassification, to label the entities: (GEN)eralizations, (UNFAIR)ness, and (STEREO)types. Labels follow BIO format. Try it out :)."
150
- "<br><br>Read more about how this model was trained in this <a href='https://huggingface.co/blog/maximuspowers/bias-entity-recognition' target='_blank'>blog post</a>."
151
- "<br>Model Page: <a href='https://huggingface.co/maximuspowers/bias-detection-ner' target='_blank'>Bias Detection NER</a>."),
152
- allow_flagging="never"
153
- )
154
-
155
- if __name__ == "__main__":
156
- iface.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from transformers import BertTokenizerFast, BertForTokenClassification
4
  import gradio as gr
5
 
6
+ # Initialize tokenizer and model
7
  tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
8
  model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
9
  model.eval()
10
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
+ # Mapping IDs to labels
13
  id2label = {
14
  0: 'O',
15
  1: 'B-STEREO',
 
22
 
23
  label2id = {v: k for k, v in id2label.items()}
24
 
25
+ # Entity colors for highlights
26
  label_colors = {
27
+ "STEREO": "rgba(255, 0, 0, 0.2)", # Light Red
28
+ "GEN": "rgba(0, 0, 255, 0.2)", # Light Blue
29
+ "UNFAIR": "rgba(0, 255, 0, 0.2)" # Light Green
30
  }
31
 
32
+ # Post-process entity tags
33
  def post_process_entities(result):
34
  prev_entity_type = None
35
+ for token_data in result:
 
36
  labels = token_data["labels"]
 
37
  labels = list(set(labels))
38
+
39
+ # Handle conflicting B- and I- tags for the same entity
40
  for entity_type in ["GEN", "UNFAIR", "STEREO"]:
41
  if f"B-{entity_type}" in labels and f"I-{entity_type}" in labels:
42
  labels.remove(f"I-{entity_type}")
43
 
44
+ # Handle sequence rules
45
  current_entity_type = None
46
  current_label = None
47
  for label in labels:
 
53
  if current_label.startswith("B-") and prev_entity_type == current_entity_type:
54
  labels.remove(current_label)
55
  labels.append(f"I-{current_entity_type}")
 
56
  if current_label.startswith("I-") and prev_entity_type != current_entity_type:
57
  labels.remove(current_label)
58
  labels.append(f"B-{current_entity_type}")
59
 
60
  prev_entity_type = current_entity_type
61
  else:
62
+ prev_entity_type = None
63
 
64
  token_data["labels"] = labels
65
  return result
66
 
67
+ # Generate JSON results
68
  def generate_json(sentence):
69
  inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
70
  input_ids = inputs['input_ids'].to(model.device)
 
88
 
89
  return json.dumps(result, indent=4)
90
 
91
+ # Predict function
92
  def predict_ner_tags_with_json(sentence):
93
  json_result = generate_json(sentence)
94
 
95
  result = json.loads(json_result)
96
 
97
+ word_row = []
98
  stereo_row = []
99
  gen_row = []
100
  unfair_row = []
 
146
 
147
  return f"{matrix_html}<br><pre>{json_result}</pre>"
148
 
149
+ # Gradio Interface
150
+ iface = gr.Blocks()
151
+
152
+ with iface:
153
+ with gr.Row():
154
+ gr.Markdown(
155
+ """
156
+ # Social Bias Named Entity Recognition (with BERT) 🕵
157
+ Enter a sentence to predict biased parts of speech tags. This model uses multi-label `BertForTokenClassification` to label the entities:
158
+ - **Generalizations (GEN)**
159
+ - **Unfairness (UNFAIR)**
160
+ - **Stereotypes (STEREO)**
161
+
162
+ Labels follow the BIO format. Try it out!
163
+
164
+ - **[Blog Post](https://huggingface.co/blog/maximuspowers/bias-entity-recognition)**
165
+ - **[Model Page](https://huggingface.co/maximuspowers/bias-detection-ner)**
166
+ """
167
+ )
168
+ with gr.Row():
169
+ input_box = gr.Textbox(label="Input Sentence")
170
+ with gr.Row():
171
+ output_box = gr.HTML(label="Entity Matrix and JSON Output")
172
+
173
+ input_box.change(predict_ner_tags_with_json, inputs=[input_box], outputs=[output_box])
174
+
175
+ iface.launch(share=True)