# Hugging Face Spaces app (Space status when this copy was captured: Sleeping).
import html
import json

import gradio as gr
import torch
from transformers import BertForTokenClassification, BertTokenizerFast
# Load the tokenizer and the fine-tuned token-classification model once at
# import time; the model is put in eval mode and placed on the GPU when one
# is available, otherwise on the CPU.
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = (
    BertForTokenClassification.from_pretrained("maximuspowers/bias-detection-ner")
    .eval()
    .to(_device)
)
# Model output index -> BIO tag. The index order is assumed to match the
# classifier head of maximuspowers/bias-detection-ner — confirm against the
# model's config if the checkpoint changes.
_BIO_TAGS = ("O", "B-STEREO", "I-STEREO", "B-GEN", "I-GEN", "B-UNFAIR", "I-UNFAIR")
id2label = dict(enumerate(_BIO_TAGS))

# Inverse lookup: BIO tag -> output index.
label2id = {tag: index for index, tag in id2label.items()}

# Translucent highlight colour per entity type (B-/I- prefix stripped),
# used when rendering the HTML entity matrix.
label_colors = dict(
    STEREO="rgba(255, 0, 0, 0.2)",
    GEN="rgba(0, 0, 255, 0.2)",
    UNFAIR="rgba(0, 255, 0, 0.2)",
)
def post_process_entities(result):
    """Normalise BIO prefixes across consecutive tokens, per entity type.

    Each element of *result* is a dict with a ``"labels"`` list. The labels are
    multi-label: a token may carry tags of several entity types at once. For
    every token, in order:

    * duplicate labels are dropped, keeping the first occurrence so the output
      order is deterministic;
    * if both ``B-X`` and ``I-X`` are present for X in GEN/UNFAIR/STEREO, the
      ``I-X`` is dropped;
    * every remaining ``B-X``/``I-X`` tag is re-prefixed independently for its
      own entity type: ``I-`` if the previous token also carried type X,
      otherwise ``B-``.

    Each entity type is tracked separately, so a token tagged with e.g. both
    GEN and STEREO gets both of its spans repaired.

    Args:
        result: List of token dicts with a ``"labels"`` key; modified in place.

    Returns:
        The same list object, for convenience.
    """
    entity_types = ("GEN", "UNFAIR", "STEREO")
    prev_types = set()  # entity types carried by the previous token
    for token_data in result:
        # dict.fromkeys de-duplicates while preserving order (set() would
        # make the order depend on the string hash seed).
        labels = list(dict.fromkeys(token_data["labels"]))
        for entity_type in entity_types:
            if f"B-{entity_type}" in labels and f"I-{entity_type}" in labels:
                labels.remove(f"I-{entity_type}")

        current_types = set()
        fixed = []
        for label in labels:
            if label.startswith(("B-", "I-")):
                entity_type = label[2:]
                # Continue a span of this type if the previous token had it,
                # otherwise start a new one.
                prefix = "I-" if entity_type in prev_types else "B-"
                label = f"{prefix}{entity_type}"
                current_types.add(entity_type)
            fixed.append(label)

        token_data["labels"] = list(dict.fromkeys(fixed))
        prev_types = current_types
    return result
def generate_json(sentence, *, threshold=0.5):
    """Tag *sentence* with the bias NER model and return the result as JSON.

    The classifier is used as a multi-label head:
    * every label whose sigmoid score exceeds *threshold* is assigned to the
      token, and a token with no such label gets ``["O"]``;
    * tokens found in ``tokenizer.all_special_tokens`` are omitted;
    * WordPiece ``##`` markers are stripped;
    * the list is passed through ``post_process_entities`` to repair BIO prefixes.

    NOTE(review): ``all_special_tokens`` presumably also contains the unknown
    token, so out-of-vocabulary positions are dropped from the output as
    well — confirm this is intended.

    Args:
        sentence: Input text; truncated to at most 128 tokens.
        threshold: Probability cutoff for assigning a label. The default of
            0.5 matches the previously hard-coded value.

    Returns:
        A JSON string (``indent=4``) holding a list of
        ``{"token": str, "labels": [str, ...]}`` objects.
    """
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Threshold on-device, then copy to the host once. Calling .item() per
    # label on a CUDA tensor would force a device sync for every label.
    flags_per_token = (torch.sigmoid(logits[0]) > threshold).cpu().tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    special_tokens = set(tokenizer.all_special_tokens)  # O(1) membership tests

    result = []
    for token, flags in zip(tokens, flags_per_token):
        if token in special_tokens:
            continue
        labels = [id2label[index] for index, active in enumerate(flags) if active] or ["O"]
        result.append({"token": token.replace("##", ""), "labels": labels})

    result = post_process_entities(result)
    return json.dumps(result, indent=4)
def predict_ner_tags_with_json(sentence):
    """Render the model's predictions for *sentence* as an HTML page fragment.

    Args:
        sentence: Raw user text from the Gradio textbox.

    Returns:
        An HTML string made of two parts:
        * a table with a "Text Sequence" row of tokens, followed by one row
          each for Generalizations, Unfairness and Stereotypes;
        * the pretty-printed JSON from ``generate_json`` inside ``<pre>``.

        Token text and the JSON dump are HTML-escaped because both are derived
        from user input.
    """
    json_result = generate_json(sentence)
    result = json.loads(json_result)

    def entity_cell(labels, entity_type):
        # Entity name(s) for this type with the B-/I- prefix stripped,
        # highlighted in the type's colour; a single space when absent.
        names = [label[2:] for label in labels if label[2:] == entity_type]
        if not names:
            return " "
        return (
            f"<span style='background:{label_colors[entity_type]}; "
            f"border-radius:6px; padding:2px 5px;'>{', '.join(names)}</span>"
        )

    def table_row(title, cells):
        # One <tr>: a bold header cell followed by one <td> per token.
        tds = "".join(f"<td>{cell}</td>" for cell in cells)
        return f"<tr><td><strong>{title}</strong></td>{tds}</tr>"

    rows = [
        (
            "Text Sequence",
            [
                f"<span style='font-weight:bold;'>{html.escape(t['token'], quote=False)}</span>"
                for t in result
            ],
        )
    ]
    for title, entity_type in (
        ("Generalizations", "GEN"),
        ("Unfairness", "UNFAIR"),
        ("Stereotypes", "STEREO"),
    ):
        rows.append((title, [entity_cell(t["labels"], entity_type) for t in result]))

    matrix_html = (
        "<table style='border-collapse:collapse; width:100%; font-family:monospace; text-align:left;'>\n"
        + "\n".join(table_row(title, cells) for title, cells in rows)
        + "\n</table>"
    )
    return f"{matrix_html}<br><pre>{html.escape(json_result, quote=False)}</pre>"
# Description shown under the title. The pieces are concatenated with no
# separator; the <br> tags provide the line breaks.
_DESCRIPTION = "".join(
    [
        "Enter a sentence to predict biased parts of speech tags. This model uses multi-label BertForTokenClassification, to label the entities: (GEN)eralizations, (UNFAIR)ness, and (STEREO)types. Labels follow BIO format. Try it out :).",
        "<br><br>Read more about how this model was trained in this <a href='https://huggingface.co/blog/maximuspowers/bias-entity-recognition' target='_blank'>blog post</a>.",
        "<br>Model Page: <a href='https://huggingface.co/maximuspowers/bias-detection-ner' target='_blank'>Bias Detection NER</a>.",
    ]
)

# Gradio UI: one text box in, one HTML panel (entity matrix + raw JSON) out.
iface = gr.Interface(
    fn=predict_ner_tags_with_json,
    inputs=[gr.Textbox(label="Input Sentence")],
    outputs=[gr.HTML(label="Entity Matrix and JSON Output")],
    title="Social Bias Named Entity Recognition (with BERT) 🕵",
    description=_DESCRIPTION,
    allow_flagging="never",
)

if __name__ == "__main__":
    # share=True requests a public Gradio share link when run locally.
    iface.launch(share=True)