Spaces:
Sleeping
Sleeping
File size: 6,152 Bytes
53a7262 9b562d8 34ab835 9b562d8 9717ed1 7a0674a 9b562d8 34ab835 9b562d8 2c53668 34ab835 7cd8165 34ab835 7cd8165 34ab835 2c53668 34ab835 2c53668 34ab835 2c53668 34ab835 2c53668 34ab835 2c53668 34ab835 2c53668 9b562d8 2c53668 9b562d8 2c53668 9b562d8 2c53668 34ab835 2c53668 34ab835 2c53668 34ab835 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import html
import json

import gradio as gr
import torch
from transformers import BertTokenizerFast, BertForTokenClassification
# Load the tokenizer and the token-classification model once, at import time.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
# Place the model on the GPU when CUDA is available, otherwise on the CPU,
# and switch it to evaluation mode for inference.
model.to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
model.eval()
# Label vocabulary: the position in this list is the label id.
id2label = dict(enumerate([
    'O',
    'B-STEREO',
    'I-STEREO',
    'B-GEN',
    'I-GEN',
    'B-UNFAIR',
    'I-UNFAIR',
]))
# Reverse lookup: label string -> label id.
label2id = {name: index for index, name in id2label.items()}
# Background colour used to highlight each entity type in the HTML output.
label_colors = dict(
    STEREO="rgba(255, 0, 0, 0.2)",  # Light Red
    GEN="rgba(0, 0, 255, 0.2)",  # Light Blue
    UNFAIR="rgba(0, 255, 0, 0.2)",  # Light Green
)
# Post-process entity tags
def post_process_entities(result):
    """Normalise the BIO tags of a token sequence so they form valid spans.

    Each item of *result* is a dict with a ``"labels"`` list holding tags such
    as ``"O"``, ``"B-GEN"`` or ``"I-STEREO"``. A token may carry tags for
    several entity types at once, so continuity is tracked independently per
    entity type:

    * a tag whose entity type was also present on the previous token becomes
      ``I-<type>`` (it continues that span);
    * a tag whose entity type was absent on the previous token becomes
      ``B-<type>`` (it opens a new span);
    * duplicate tags, and a ``B-``/``I-`` pair of the same type on one token,
      collapse into a single tag.

    Tags that are not ``B-``/``I-`` prefixed (e.g. ``"O"``) are kept as they
    are, and the first-seen order of the labels is preserved so the output is
    deterministic.

    Args:
        result: list of token dicts; each dict's ``"labels"`` entry is
            replaced in place.

    Returns:
        The same ``result`` list, with normalised labels.
    """
    previous_types = set()
    for token_data in result:
        current_types = set()
        normalized = []
        for label in token_data["labels"]:
            entity_type = label[2:]
            if label.startswith(("B-", "I-")) and entity_type:
                current_types.add(entity_type)
                # Continue the span if the previous token had this type,
                # otherwise open a new one.
                prefix = "I-" if entity_type in previous_types else "B-"
                label = prefix + entity_type
            # After normalisation B-X and I-X map to the same tag, so this
            # single check removes both duplicates and B-/I- conflicts.
            if label not in normalized:
                normalized.append(label)
        token_data["labels"] = normalized
        previous_types = current_types
    return result
# Generate JSON results
def generate_json(sentence, threshold=0.5):
    """Tag *sentence* with the model and return the result as a JSON string.

    The sentence is tokenized (truncated to at most 128 tokens), run through
    the module-level ``model``, and every label whose sigmoid probability
    exceeds *threshold* is assigned to the token, so one token can receive
    several labels. Tokens with no label above the threshold get ``'O'``.
    Special tokens of the tokenizer are dropped, ``"##"`` is removed from the
    remaining token strings, and the tags are normalised with
    ``post_process_entities``.

    Args:
        sentence: the input text to analyse.
        threshold: probability above which a label is assigned to a token.

    Returns:
        A JSON array (indented by 4 spaces) of
        ``{"token": str, "labels": [str, ...]}`` objects, one per kept token.
    """
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Labels are scored independently (sigmoid, not softmax). Index 0 selects
    # the only sequence in the batch, leaving one row of label flags per token.
    token_flags = torch.sigmoid(logits)[0] > threshold
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    # Look the special tokens up once instead of once per token.
    special_tokens = set(tokenizer.all_special_tokens)

    result = []
    for token, flags in zip(tokens, token_flags):
        if token in special_tokens:
            continue
        label_ids = flags.nonzero(as_tuple=False).squeeze(-1).tolist()
        labels = [id2label[label_id] for label_id in label_ids] or ['O']
        result.append({"token": token.replace("##", ""), "labels": labels})

    result = post_process_entities(result)
    return json.dumps(result, indent=4)
# Predict function
def _table_row(title, cells):
    """Return one ``<tr>`` of the entity matrix: a bold title cell, then *cells*."""
    body = ''.join(f"<td>{cell}</td>" for cell in cells)
    return f"<tr>\n<td><strong>{title}</strong></td>\n{body}\n</tr>"


def predict_ner_tags_with_json(sentence):
    """Render the predictions for *sentence* as an HTML fragment.

    The fragment contains a table with one column per token and four rows
    (the tokens, then the Generalizations, Unfairness and Stereotypes tags),
    followed by the raw JSON produced by ``generate_json`` in a ``<pre>``
    block.

    Args:
        sentence: the input text to analyse.

    Returns:
        An HTML string: the table, a ``<br>``, and the JSON dump.
    """
    json_result = generate_json(sentence)
    result = json.loads(json_result)

    # (entity type, row title) in the order the rows appear in the table.
    sections = (
        ("GEN", "Generalizations"),
        ("UNFAIR", "Unfairness"),
        ("STEREO", "Stereotypes"),
    )
    word_row = []
    tag_rows = {entity_type: [] for entity_type, _ in sections}

    for token_data in result:
        # Tokens come from user input: escape them so characters such as
        # '<' or '&' are displayed literally instead of parsed as markup.
        token = html.escape(token_data["token"])
        word_row.append(f"<span style='font-weight:bold;'>{token}</span>")

        labels = token_data["labels"]
        for entity_type, cells in tag_rows.items():
            # Drop the "B-"/"I-" prefix; only the entity type is displayed.
            matches = [label[2:] for label in labels if entity_type in label]
            cells.append(
                f"<span style='background:{label_colors[entity_type]}; border-radius:6px; padding:2px 5px;'>{', '.join(matches)}</span>"
                if matches else " "
            )

    rows = [_table_row("Text Sequence", word_row)]
    rows.extend(_table_row(title, tag_rows[entity_type]) for entity_type, title in sections)
    matrix_html = (
        "\n<table style='border-collapse:collapse; width:100%; font-family:monospace; text-align:left;'>\n"
        + "\n".join(rows)
        + "\n</table>\n"
    )
    # The JSON dump embeds the same user-derived tokens, so escape it as well.
    return f"{matrix_html}<br><pre>{html.escape(json_result, quote=False)}</pre>"
# Gradio Interface: a heading, one text input and one HTML output.
with gr.Blocks() as iface:
    with gr.Row():
        gr.Markdown(
            """
# Social Bias Named Entity Recognition (with BERT) 🕵
Enter a sentence to predict biased parts of speech tags. This model uses multi-label `BertForTokenClassification` to label the entities:
- **Generalizations (GEN)**
- **Unfairness (UNFAIR)**
- **Stereotypes (STEREO)**
Labels follow the BIO format. Try it out!
- **[Blog Post](https://huggingface.co/blog/maximuspowers/bias-entity-recognition)**
- **[Model Page](https://huggingface.co/maximuspowers/bias-detection-ner)**
"""
        )
    with gr.Row():
        input_box = gr.Textbox(label="Input Sentence")
    with gr.Row():
        output_box = gr.HTML(label="Entity Matrix and JSON Output")
    # Re-run the prediction whenever the text box content changes.
    input_box.change(
        fn=predict_ner_tags_with_json,
        inputs=[input_box],
        outputs=[output_box],
    )

iface.launch(share=True)
|