File size: 3,678 Bytes
53a7262
9b562d8
 
 
 
7cd8165
9b562d8
9717ed1
7a0674a
 
9b562d8
7cd8165
9b562d8
 
 
 
 
 
 
 
 
 
7cd8165
 
939c704
 
 
7cd8165
 
939c704
 
 
 
 
 
 
 
 
7cd8165
9b562d8
 
 
 
 
 
 
 
 
7cd8165
9b562d8
 
7cd8165
939c704
 
9b562d8
 
939c704
9b562d8
939c704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b562d8
939c704
9b562d8
7cd8165
9b562d8
 
5dfca2c
 
53a7262
2300f65
 
607049d
7a0674a
9b562d8
 
 
7cd8165
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import json
import torch
from transformers import BertTokenizerFast, BertForTokenClassification
import gradio as gr

# Initialize important things
# WordPiece tokenizer for the base (uncased) BERT vocabulary.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Token-classification checkpoint used for the bias labels.
model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')

# Switch to inference mode, then move the weights to the GPU when one is
# available (CPU otherwise). Both calls hand back the module, so they chain.
model.eval().to('cuda' if torch.cuda.is_available() else 'cpu')

# IDs to labels we want to display
# Tag names in class-index order: position in this sequence is the class id.
# 'O' marks "no entity"; B-/I- prefixes are the begin/inside BIO markers.
id2label = dict(enumerate((
    'O',
    'B-STEREO',
    'I-STEREO',
    'B-GEN',
    'I-GEN',
    'B-UNFAIR',
    'I-UNFAIR',
)))

# Color map for entities
# Semi-transparent CSS background colour for each entity type
# (keys are the tag names without their B-/I- prefix).
label_colors = dict(
    STEREO="rgba(255, 0, 0, 0.3)",  # red
    GEN="rgba(0, 0, 255, 0.3)",     # blue
    UNFAIR="rgba(0, 255, 0, 0.3)",  # green
)

# Helper to wrap a token in a span with color
def wrap_token_with_color(token, labels):
    """Wrap ``token`` in a ``<span>`` styled according to ``labels``.

    The inline style always starts with ``position: relative;`` and then
    gets one ``background`` declaration per label that is not ``"O"``, in
    the order the labels are given. Colours are looked up in
    ``label_colors``; a label missing from that mapping raises ``KeyError``.
    """
    backgrounds = "".join(
        f"background: {label_colors[name]};"
        for name in labels
        if name != "O"
    )
    css = "position: relative;" + backgrounds
    return f"<span style='{css}'>{token}</span>"

# Predict function
def predict_ner_tags(sentence):
    """Return ``sentence`` as HTML with the predicted bias entities highlighted.

    The sentence is tokenized (truncated to 128 tokens), run through the
    classifier, and every class whose sigmoid score exceeds 0.5 is treated
    as active for that token (multi-label). Active classes are mapped
    through ``id2label``, reduced to their entity type (``B-GEN`` and
    ``I-GEN`` both become ``GEN``) and de-duplicated; ``O`` and any type
    without an entry in ``label_colors`` do not highlight.

    The output is rebuilt from the tokenizer's character offsets into the
    original sentence rather than from the WordPiece strings, so spacing,
    casing and out-of-vocabulary words are shown as typed. Consecutive
    tokens with the same set of entity types share a single ``<span>``.
    All text is HTML-escaped before being emitted.

    Args:
        sentence: The text to analyse. ``None`` or an empty string yields
            an empty result.

    Returns:
        An HTML fragment (``str``). Text beyond the 128-token limit is not
        included.
    """
    import html  # function-scope so the file's import block is untouched

    if not sentence:
        return ""

    inputs = tokenizer(
        sentence,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        return_offsets_mapping=True,
    )
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # One row of booleans per token: which classes passed the 0.5 threshold.
    active_rows = (torch.sigmoid(logits[0]) > 0.5).tolist()
    # (char_start, char_end) per token, indexing into `sentence`.
    offsets = inputs['offset_mapping'][0].tolist()

    pieces = []
    open_labels = []  # entity types of the span currently open; [] = none
    cursor = 0        # end offset of the last character already emitted

    for (start, end), row in zip(offsets, active_rows):
        # Tokens added by the tokenizer ([CLS], [SEP], padding) cover no
        # characters of the input, so they have an empty offset range.
        if start == end:
            continue

        labels = []
        for class_id, is_active in enumerate(row):
            if not is_active:
                continue
            # 'B-GEN' -> 'GEN'; 'O' has no prefix and stays 'O', which is
            # then dropped because it has no colour.
            entity = id2label.get(class_id, 'O').split('-', 1)[-1]
            if entity in label_colors and entity not in labels:
                labels.append(entity)

        # Characters between the previous token and this one (usually
        # whitespace; empty between the pieces of a split word).
        gap = html.escape(sentence[cursor:start])

        if labels == open_labels:
            # Same highlight continues: keep the gap inside the span so the
            # background is seamless across words.
            pieces.append(gap)
        else:
            if open_labels:
                pieces.append("</span>")
            pieces.append(gap)
            if labels:
                colors = [label_colors[entity] for entity in labels]
                # A gradient needs at least two colour stops; a single
                # entity type is rendered as a plain background colour.
                if len(colors) == 1:
                    background = colors[0]
                else:
                    background = f"linear-gradient({', '.join(colors)})"
                pieces.append(f"<span style='background: {background}'>")

        pieces.append(html.escape(sentence[start:end]))
        open_labels = labels
        cursor = end

    # Close the span left open by the final token, if any.
    if open_labels:
        pieces.append("</span>")

    return "".join(pieces)

# Gradio Interface
# Single-textbox UI: the sentence goes in, the highlighted HTML produced by
# predict_ner_tags comes out. Flagging is turned off.
iface = gr.Interface(
    title="Social Bias Named Entity Recognition (with BERT) 🕵",
    description=(
        "Enter a sentence to predict biased parts of speech tags. This model uses multi-label BertForTokenClassification, to label the entities: (GEN)eralizations, (UNFAIR)ness, and (STEREO)types. Labels follow BIO format. Try it out :)."
        + "<br><br>Read more about how this model was trained in this <a href='https://huggingface.co/blog/maximuspowers/bias-entity-recognition' target='_blank'>blog post</a>."
        + "<br>Model Page: <a href='https://huggingface.co/maximuspowers/bias-detection-ner' target='_blank'>Bias Detection NER</a>."
    ),
    fn=predict_ner_tags,
    inputs=gr.Textbox(label="Input Sentence"),
    outputs=gr.HTML(label="Highlighted Sentence"),
    allow_flagging="never",
)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()