# maximuspowers's picture
# Update app.py
# 5b65826 verified
# raw
# history blame
# 3.99 kB
import html
import json

import gradio as gr
import torch
from transformers import BertTokenizerFast, BertForTokenClassification
# Load the tokenizer and the token-classification model once, at import time
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
# Switch to evaluation mode, then place the model on the GPU when one is available
model.eval()
if torch.cuda.is_available():
    model.to('cuda')
else:
    model.to('cpu')
# Class index -> BIO tag; a tag's position in the list is its class index
id2label = dict(enumerate([
    'O',
    'B-STEREO',
    'I-STEREO',
    'B-GEN',
    'I-GEN',
    'B-UNFAIR',
    'I-UNFAIR',
]))
# Entity type -> translucent CSS background color (alpha fixed at 0.3)
label_colors = {
    entity: f"rgba({red}, {green}, {blue}, 0.3)"
    for entity, (red, green, blue) in (
        ("STEREO", (255, 0, 0)),  # Red
        ("GEN", (0, 0, 255)),  # Blue
        ("UNFAIR", (0, 255, 0)),  # Green
    )
}
# Helper to wrap a token in a span with color
def wrap_token_with_color(token, labels):
    """Wrap *token* in a ``<span>`` whose background reflects *labels*.

    Args:
        token: Text to display. It is converted to ``str`` and HTML-escaped
            before being embedded, so markup characters show up literally.
        labels: Iterable of entity-type names. Names that have no entry in
            ``label_colors`` (including ``"O"``) are ignored.

    Returns:
        The HTML string ``<span style='...'>token</span>``. With no known
        label the span carries only ``position: relative;``.
    """
    # De-duplicate while keeping first-seen order so a repeated label does not
    # produce a pointless multi-stop gradient.
    colors = list(dict.fromkeys(
        label_colors[label] for label in labels if label in label_colors
    ))

    style = "position: relative;"
    if len(colors) == 1:
        style += f"background: {colors[0]};"
    elif colors:
        # Repeating `background:` once per label would let the last declaration
        # override the others; a gradient keeps every label's color visible.
        style += f"background: linear-gradient({', '.join(colors)});"
    return f"<span style='{style}'>{html.escape(str(token))}</span>"
# Predict function
def _entity_types(label_ids):
    """Map active class indices to ordered, de-duplicated entity types that have a color."""
    entity_types = []
    for label_id in label_ids:
        tag = id2label.get(label_id, "O")
        if tag == "O":
            continue
        entity = tag.split("-", 1)[-1]  # drop the B-/I- prefix
        if entity in label_colors and entity not in entity_types:
            entity_types.append(entity)
    return entity_types


def _highlight_background(entity_types):
    """Return the CSS ``background`` value for a non-empty list of entity types."""
    colors = [label_colors[entity] for entity in entity_types]
    if len(colors) == 1:
        # A one-stop linear-gradient() is rejected by CSS Images Level 3
        # parsers, so a single label uses its plain color.
        return colors[0]
    return f"linear-gradient({', '.join(colors)})"


def predict_ner_tags(sentence):
    """Return *sentence* as HTML with the predicted bias entities highlighted.

    The sentence is tokenized (input beyond 128 tokens is truncated) and run
    through the model. A sigmoid is applied to every class logit, and a class
    counts as active for a token when its probability exceeds 0.5, so one
    token may carry several entity types at once.

    Consecutive tokens with the same set of entity types share one ``<span>``.
    Tokens with no colored entity type are emitted outside any span.

    Args:
        sentence: The text to analyze.

    Returns:
        An HTML string built from the tokenizer's tokens. Word pieces are
        re-joined, words are separated by single spaces, and token text is
        HTML-escaped. Returns ``""`` when the input yields only special
        tokens.
    """
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    predicted_labels = (torch.sigmoid(logits) > 0.5).int()  # Threshold

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    special_tokens = set(tokenizer.all_special_tokens)  # hoisted out of the loop

    pieces = []
    open_labels = []  # entity types of the currently open span; [] when none is open
    for i, token in enumerate(tokens):
        if token in special_tokens:
            continue

        active_ids = predicted_labels[0][i].nonzero(as_tuple=False).squeeze(-1).tolist()
        labels = _entity_types(active_ids)

        # "##" marks a word-piece continuation: it attaches to the previous
        # piece, while every other token starts a new, space-separated word.
        is_continuation = token.startswith("##")
        text = token[2:] if is_continuation else token
        separator = " " if pieces and not is_continuation else ""

        if labels != open_labels:
            # Close only a span that was actually opened, and keep the word
            # separator outside both the old and the new span.
            if open_labels:
                pieces.append("</span>")
            pieces.append(separator)
            if labels:
                pieces.append(f"<span style='background: {_highlight_background(labels)};'>")
            open_labels = labels
        else:
            # Same labels as the previous token: the separator stays inside the
            # open span so the highlight is seamless.
            pieces.append(separator)

        # Token text originates from user input; escape it before embedding.
        pieces.append(html.escape(text))

    # Close any open span
    if open_labels:
        pieces.append("</span>")
    return "".join(pieces)
# Gradio Interface: one text box in, one HTML panel out, flagging disabled
iface = gr.Interface(
    fn=predict_ner_tags,
    inputs=gr.Textbox(label="Input Sentence"),
    outputs=gr.HTML(label="Highlighted Sentence"),
    # U+1F575 (detective emoji) is written as an escape so the title survives
    # encoding round-trips of this source file.
    title="Social Bias Named Entity Recognition (with BERT) \U0001F575",
    description=("Enter a sentence to predict biased parts of speech tags. This model uses multi-label BertForTokenClassification, to label the entities: (GEN)eralizations, (UNFAIR)ness, and (STEREO)types. Labels follow BIO format. Try it out :)."
                 "<br><br>Read more about how this model was trained in this <a href='https://huggingface.co/blog/maximuspowers/bias-entity-recognition' target='_blank'>blog post</a>."
                 "<br>Model Page: <a href='https://huggingface.co/maximuspowers/bias-detection-ner' target='_blank'>Bias Detection NER</a>."),
    allow_flagging="never"
)

# Start the app only when this file is run as a script
if __name__ == "__main__":
    iface.launch(share=True)