Spaces:
Sleeping
Sleeping
import html
import json

import gradio as gr
import torch
from transformers import BertForTokenClassification, BertTokenizerFast
# Load the WordPiece tokenizer and the fine-tuned bias-NER checkpoint once, at
# import time, so every request reuses the same objects.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
# eval() switches the model to inference mode and returns the module itself, so
# the device move can be chained: GPU when CUDA is available, otherwise CPU.
model.eval().to('cuda' if torch.cuda.is_available() else 'cpu')
# Classifier output index -> BIO tag ("O" = outside any entity, "B-"/"I-" =
# beginning/inside of a STEREO, GEN or UNFAIR span).
# NOTE(review): the order is assumed to mirror the checkpoint's label ids —
# confirm against the model config.
id2label = dict(enumerate([
    'O',
    'B-STEREO', 'I-STEREO',
    'B-GEN', 'I-GEN',
    'B-UNFAIR', 'I-UNFAIR',
]))
# Translucent CSS highlight colour for each entity type (BIO prefix removed).
label_colors = dict(
    STEREO="rgba(255, 0, 0, 0.3)",  # red
    GEN="rgba(0, 0, 255, 0.3)",  # blue
    UNFAIR="rgba(0, 255, 0, 0.3)",  # green
)
# Helper to wrap a token in a span coloured by its entity labels
def wrap_token_with_color(token, labels, colors=None):
    """Wrap *token* in a ``<span>`` whose background shows all of its labels.

    Args:
        token: Text to display; it is HTML-escaped before insertion.
        labels: Entity types (e.g. ``"STEREO"``, ``"GEN"``, ``"UNFAIR"``) or
            ``"O"``. ``"O"`` and labels without a colour are ignored; a colour
            used by several labels is shown once.
        colors: Optional mapping of label -> CSS colour. Defaults to the
            module-level ``label_colors``.

    Returns:
        ``<span style='position: relative;...'>token</span>``. With one colour
        the background is that colour; with several it is a ``linear-gradient``
        through them; with none there is no background.
    """
    if colors is None:
        colors = label_colors
    highlight = []
    for label in labels:
        if label != "O" and label in colors and colors[label] not in highlight:
            highlight.append(colors[label])
    style = "position: relative;"
    if len(highlight) == 1:
        # linear-gradient() needs at least two colour stops, so a single
        # colour must be a plain background.
        style += f"background: {highlight[0]};"
    elif highlight:
        # Repeated `background:` declarations would override each other and
        # show only the last colour; blend them into one gradient instead.
        style += f"background: linear-gradient({', '.join(highlight)});"
    return f"<span style='{style}'>{html.escape(token)}</span>"
# Predict function
def _entity_types(label_flags):
    """Return the entity types switched on in one token's multi-label prediction.

    Args:
        label_flags: Per-label booleans for a single token, indexed by label id.

    Returns:
        Entity names with the BIO prefix stripped (``"B-GEN"``/``"I-GEN"`` ->
        ``"GEN"``), in label-id order and without duplicates. ``"O"`` and ids
        missing from ``id2label`` are ignored, so an unlabelled token yields ``[]``.
    """
    entities = []
    for label_id, is_on in enumerate(label_flags):
        label = id2label.get(label_id)
        if not is_on or label is None or label == "O":
            continue
        entity = label[2:]
        if entity not in entities:
            entities.append(entity)
    return entities


def _highlight_style(entities):
    """Return the inline CSS background for *entities*, or None if none has a colour."""
    colors = [label_colors[entity] for entity in entities if entity in label_colors]
    if not colors:
        return None
    if len(colors) == 1:
        # linear-gradient() needs at least two colour stops, so a single
        # colour must be a plain background or nothing is highlighted.
        return f"background: {colors[0]};"
    return f"background: linear-gradient({', '.join(colors)});"


def predict_ner_tags(sentence):
    """Run the bias NER model on *sentence* and return it as highlighted HTML.

    The sentence is tokenized with the module-level ``tokenizer``, truncated to
    128 tokens including special tokens, and run through ``model``. Every label
    whose sigmoid score exceeds 0.5 is kept, so a token may carry several labels.
    The original text of each word piece is recovered from the tokenizer's
    character offsets, so casing, spacing and punctuation are preserved.
    Consecutive pieces with the same entity types share one ``<span>`` coloured
    from ``label_colors``.

    Args:
        sentence: Raw user text.

    Returns:
        HTML-escaped text with balanced highlight ``<span>`` tags. Text beyond
        the truncation limit is omitted.
    """
    inputs = tokenizer(
        sentence,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
        return_offsets_mapping=True,
    )
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)
    # (start, end) character span of each token in `sentence`; [CLS]/[SEP]/padding get (0, 0).
    offsets = inputs['offset_mapping'][0].tolist()
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    # Multi-label head: an independent sigmoid per label (not a softmax), thresholded at 0.5.
    predicted_labels = (torch.sigmoid(outputs.logits) > 0.5)[0].tolist()

    parts = []
    prev_entities = None  # entity types of the previously emitted piece
    span_open = False  # whether a highlight <span> is currently open
    cursor = 0  # index in `sentence` just past the last emitted piece
    for (start, end), label_flags in zip(offsets, predicted_labels):
        if start == end:
            # Special tokens cover no characters of the input.
            continue
        entities = _entity_types(label_flags)
        gap = html.escape(sentence[cursor:start])
        if entities != prev_entities:
            if span_open:
                parts.append("</span>")
                span_open = False
            # Whitespace between differently-labelled pieces stays unhighlighted.
            parts.append(gap)
            style = _highlight_style(entities)
            if style is not None:
                parts.append(f"<span style='{style}'>")
                span_open = True
        else:
            # Same labels as the previous piece: keep the highlight continuous.
            parts.append(gap)
        # Escape user text so it cannot inject markup into the HTML output.
        parts.append(html.escape(sentence[start:end]))
        prev_entities = entities
        cursor = end
    if span_open:
        parts.append("</span>")
    return "".join(parts)
# Gradio front end: one text box in, the highlighted HTML from predict_ner_tags out.
# NOTE(review): the title's trailing "π΅" looks like a mis-decoded emoji —
# confirm the intended character before changing it.
iface = gr.Interface(
    fn=predict_ner_tags,
    inputs=gr.Textbox(label="Input Sentence"),
    outputs=gr.HTML(label="Highlighted Sentence"),
    title="Social Bias Named Entity Recognition (with BERT) π΅",
    description="".join((
        "Enter a sentence to predict biased parts of speech tags. This model uses multi-label BertForTokenClassification, to label the entities: (GEN)eralizations, (UNFAIR)ness, and (STEREO)types. Labels follow BIO format. Try it out :).",
        "<br><br>Read more about how this model was trained in this <a href='https://huggingface.co/blog/maximuspowers/bias-entity-recognition' target='_blank'>blog post</a>.",
        "<br>Model Page: <a href='https://huggingface.co/maximuspowers/bias-detection-ner' target='_blank'>Bias Detection NER</a>.",
    )),
    # Hides Gradio's "Flag" button.
    # NOTE(review): newer Gradio versions deprecate this in favour of
    # `flagging_mode` — confirm against the pinned version.
    allow_flagging="never",
)

if __name__ == "__main__":
    # share=True additionally requests a temporary public Gradio link.
    iface.launch(share=True)