Spaces:
Sleeping
Sleeping
File size: 5,780 Bytes
53a7262 9b562d8 9717ed1 7a0674a 9b562d8 2c53668 7cd8165 2c53668 7cd8165 2c53668 9b562d8 2c53668 9b562d8 2c53668 9b562d8 2c53668 9b562d8 2c53668 53a7262 2300f65 607049d 7a0674a 9b562d8 5b65826 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import html
import json

import torch
from transformers import BertTokenizerFast, BertForTokenClassification
import gradio as gr
# Tokenizer and fine-tuned token-classification model, loaded once at import
# time and shared by every prediction call below.
# NOTE(review): the tokenizer is taken from the base 'bert-base-uncased'
# checkpoint rather than the fine-tuned repo; presumably the fine-tuned model
# kept that vocabulary unchanged -- confirm.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = (
    BertForTokenClassification
    .from_pretrained('maximuspowers/bias-detection-ner')
    # Run on the GPU when one is available, otherwise on the CPU.
    .to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
)
# Inference only: switch off training-mode behaviour such as dropout.
model.eval()
# Model output index -> BIO tag. "B-" opens an entity span, "I-" continues
# one, and "O" means the token is outside every entity.
id2label = dict(enumerate((
    'O',
    'B-STEREO', 'I-STEREO',
    'B-GEN', 'I-GEN',
    'B-UNFAIR', 'I-UNFAIR',
)))

# Reverse lookup: BIO tag -> model output index.
label2id = {tag: index for index, tag in id2label.items()}

# Translucent CSS highlight colour per entity type, used in the HTML matrix.
label_colors = dict(
    STEREO="rgba(255, 0, 0, 0.2)",
    GEN="rgba(0, 0, 255, 0.2)",
    UNFAIR="rgba(0, 255, 0, 0.2)",
)
def post_process_entities(result):
    """Make per-token multi-label predictions BIO-consistent, in place.

    Each element of *result* is a dict with a "labels" list (e.g.
    ``["B-GEN", "B-STEREO"]`` or ``["O"]``). For every token:

    * duplicate labels are dropped, keeping first-seen order;
    * if a token carries both ``B-T`` and ``I-T`` for the same type ``T``,
      only ``B-T`` is kept;
    * every ``B-``/``I-`` label is re-prefixed independently per entity type:
      ``I-T`` when the previous token carried a label of type ``T``,
      otherwise ``B-T``. A span therefore always starts with ``B-`` and
      continues with ``I-``, even when a token belongs to several entity
      types at once.

    Labels without a ``B-``/``I-`` prefix (such as ``"O"``) pass through
    unchanged.

    Args:
        result: List of token dicts; each dict's "labels" entry is replaced.

    Returns:
        The same list object, for call chaining.
    """
    prev_types = set()
    for token_data in result:
        # dict.fromkeys de-duplicates while preserving order; set() would make
        # the label order (and which label gets repaired) hash-dependent.
        labels = list(dict.fromkeys(token_data["labels"]))
        # A token tagged both B-T and I-T keeps only the span start.
        labels = [
            label for label in labels
            if not (label.startswith("I-") and f"B-{label[2:]}" in labels)
        ]

        fixed = []
        current_types = set()
        for label in labels:
            if label.startswith(("B-", "I-")):
                entity_type = label[2:]
                # Continue the previous token's span of this type, else open one.
                prefix = "I-" if entity_type in prev_types else "B-"
                label = f"{prefix}{entity_type}"
                current_types.add(entity_type)
            fixed.append(label)

        token_data["labels"] = fixed
        prev_types = current_types
    return result
def generate_json(sentence, threshold=0.5):
    """Predict bias entity labels for each (sub)word token of *sentence*.

    The model is multi-label: each token's logits go through an independent
    sigmoid and every label whose probability exceeds *threshold* is kept,
    so one token may carry several labels. Tokens with no label above the
    threshold are tagged 'O'. The labels are then BIO-repaired by
    post_process_entities.

    Args:
        sentence: Input text. It is truncated to 128 tokens (including the
            special tokens); anything beyond that is not labelled.
        threshold: Sigmoid probability cut-off for assigning a label.
            Defaults to 0.5, the previously hard-coded value.

    Returns:
        A JSON string (indent=4) encoding a list of
        ``{"token": str, "labels": [str, ...]}`` objects, one per non-special
        WordPiece token, with the leading "##" continuation marker stripped.
        [UNK] tokens are kept so that unrepresentable words still appear.
    """
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    # Independent per-label sigmoid (not softmax): labels are not mutually exclusive.
    predicted = torch.sigmoid(logits)[0] > threshold

    # Hoisted out of the loop and turned into a set for O(1) membership tests.
    # [UNK] is deliberately not skipped: dropping it made words outside the
    # vocabulary silently disappear from the output.
    skip_tokens = set(tokenizer.all_special_tokens) - {tokenizer.unk_token}

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    result = []
    for position, token in enumerate(tokens):
        if token in skip_tokens:
            continue
        label_ids = predicted[position].nonzero(as_tuple=True)[0].tolist()
        labels = [id2label[label_id] for label_id in label_ids] or ['O']
        # Strip only the leading WordPiece continuation marker, not every "##".
        if token.startswith("##"):
            token = token[2:]
        result.append({"token": token, "labels": labels})
    result = post_process_entities(result)
    return json.dumps(result, indent=4)
def _entity_cell(labels, entity_type):
    """Return the HTML for one token's cell in the *entity_type* row.

    The cell shows the entity type name (one entry per matching label, joined
    with ", ") on a badge coloured from label_colors. It is a single space
    when the token carries no label of that type.
    """
    matches = [label[2:] for label in labels if label[2:] == entity_type]
    if not matches:
        return " "
    return (
        f"<span style='background:{label_colors[entity_type]}; "
        f"border-radius:6px; padding:2px 5px;'>{', '.join(matches)}</span>"
    )


def predict_ner_tags_with_json(sentence):
    """Run bias NER on *sentence* and render the predictions as HTML.

    Returns an HTML string with two parts. The first is a table with one
    column per token and four rows: the token text, then the Generalizations,
    Unfairness and Stereotypes rows. The second is the raw JSON predictions
    from generate_json inside a <pre> block. All user-derived text is
    HTML-escaped before it is embedded.
    """
    json_result = generate_json(sentence)
    result = json.loads(json_result)

    # (row header, entity type) in display order.
    entity_rows = (
        ("Generalizations", "GEN"),
        ("Unfairness", "UNFAIR"),
        ("Stereotypes", "STEREO"),
    )

    # Token text originates from user input, so escape it before inserting into HTML.
    rows = [(
        "Text Sequence",
        [f"<span style='font-weight:bold;'>{html.escape(t['token'])}</span>" for t in result],
    )]
    rows.extend(
        (title, [_entity_cell(t["labels"], entity_type) for t in result])
        for title, entity_type in entity_rows
    )

    row_html = []
    for title, cells in rows:
        tds = "".join(f"<td>{cell}</td>" for cell in cells)
        row_html.append(f"<tr>\n<td><strong>{title}</strong></td>\n{tds}\n</tr>")

    matrix_html = (
        "<table style='border-collapse:collapse; width:100%; "
        "font-family:monospace; text-align:left;'>\n"
        + "\n".join(row_html)
        + "\n</table>"
    )
    return f"{matrix_html}<br><pre>{html.escape(json_result)}</pre>"
# Single-textbox Gradio UI that renders the entity matrix and JSON as HTML.
iface = gr.Interface(
    title="Social Bias Named Entity Recognition (with BERT) 🕵",
    description=(
        "Enter a sentence to predict biased parts of speech tags. "
        "This model uses multi-label BertForTokenClassification, to label the entities: "
        "(GEN)eralizations, (UNFAIR)ness, and (STEREO)types. Labels follow BIO format. Try it out :)."
        "<br><br>Read more about how this model was trained in this "
        "<a href='https://huggingface.co/blog/maximuspowers/bias-entity-recognition' target='_blank'>blog post</a>."
        "<br>Model Page: "
        "<a href='https://huggingface.co/maximuspowers/bias-detection-ner' target='_blank'>Bias Detection NER</a>."
    ),
    fn=predict_ner_tags_with_json,
    inputs=[gr.Textbox(label="Input Sentence")],
    outputs=[gr.HTML(label="Entity Matrix and JSON Output")],
    allow_flagging="never",
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link when launched as a script.
    iface.launch(share=True)
|