import json
import torch
from transformers import BertTokenizerFast, BertForTokenClassification
import gradio as gr
# Initialize the tokenizer and model
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
model.eval()
model.to('cuda' if torch.cuda.is_available() else 'cpu')
# IDs to labels we want to display
id2label = {
    0: 'O',
    1: 'B-STEREO',
    2: 'I-STEREO',
    3: 'B-GEN',
    4: 'I-GEN',
    5: 'B-UNFAIR',
    6: 'I-UNFAIR'
}
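# Note: this is a multi-label scheme, so one token can carry several tags at once
# (illustrative example, not actual model output: "women" in "women are bad drivers"
# might be tagged both B-GEN and B-STEREO). The B-/I- prefixes mark the beginning/inside of a span.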
# Color map for entities
label_colors = {
    "STEREO": "rgba(255, 0, 0, 0.3)",  # Red
    "GEN": "rgba(0, 0, 255, 0.3)",     # Blue
    "UNFAIR": "rgba(0, 255, 0, 0.3)"   # Green
}
# Helper to wrap a token in a span with color (currently unused; predict_ner_tags builds its spans inline)
def wrap_token_with_color(token, labels):
    # Apply a background color for each non-O label (later labels override earlier ones)
    style = "position: relative;"
    for label in labels:
        if label != "O":
            style += f"background: {label_colors[label]};"
    return f"<span style='{style}'>{token}</span>"
# Predict function
def predict_ner_tags(sentence):
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probabilities = torch.sigmoid(logits)
        predicted_labels = (probabilities > 0.5).int()  # Multi-label threshold

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    highlighted_sentence = ""
    prev_labels = []
    for i, token in enumerate(tokens):
        if token not in tokenizer.all_special_tokens:
            # Extract the labels for this token, dropping 'O' and the B-/I- prefixes
            label_indices = (predicted_labels[0][i] == 1).nonzero(as_tuple=False).squeeze(-1)
            raw_labels = [id2label[idx.item()] for idx in label_indices]
            labels = [label[2:] for label in raw_labels if label != 'O'] or ['O']

            # Separate whole words with a space (WordPiece continuations start with "##")
            if not token.startswith("##") and prev_labels:
                highlighted_sentence += " "

            # Check if labels are the same as the previous token (for seamless highlighting)
            if labels != prev_labels:
                if prev_labels and prev_labels != ["O"]:  # Close the previous span if one is open
                    highlighted_sentence += "</span>"
                # Start a new span
                if labels != ["O"]:
                    colors = [label_colors[label] for label in labels]
                    # linear-gradient() needs at least two color stops; use a plain background for a single label
                    background = colors[0] if len(colors) == 1 else f"linear-gradient({', '.join(colors)})"
                    highlighted_sentence += f"<span style='background: {background}'>"

            # Add the token, stripping the WordPiece continuation marker
            highlighted_sentence += token[2:] if token.startswith("##") else token
            prev_labels = labels

    # Close any open span
    if prev_labels and prev_labels != ["O"]:
        highlighted_sentence += "</span>"
    return highlighted_sentence
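
# Quick sanity check (illustrative only; the exact spans depend on the model's predictions):
#   html = predict_ner_tags("Women are bad drivers.")
#   print(html)  # HTML string with <span style='background: ...'> wrappers around tagged spans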
# Gradio Interface
iface = gr.Interface(
    fn=predict_ner_tags,
    inputs=gr.Textbox(label="Input Sentence"),
    outputs=gr.HTML(label="Highlighted Sentence"),
    title="Social Bias Named Entity Recognition (with BERT) 🕵",
    description=("Enter a sentence to highlight the spans the model tags as biased. This model uses multi-label BertForTokenClassification to label the entities: (GEN)eralizations, (UNFAIR)ness, and (STEREO)types. Labels follow the BIO format. Try it out :)."
                 "<br><br>Read more about how this model was trained in this <a href='https://huggingface.co/blog/maximuspowers/bias-entity-recognition' target='_blank'>blog post</a>."
                 "<br>Model Page: <a href='https://huggingface.co/maximuspowers/bias-detection-ner' target='_blank'>Bias Detection NER</a>."),
    allow_flagging="never"
)
if __name__ == "__main__":
    iface.launch()