File size: 2,605 Bytes
9b562d8
 
 
 
7a0674a
9b562d8
9717ed1
7a0674a
 
9b562d8
b75c854
9b562d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b75c854
9b562d8
b75c854
 
 
 
 
 
 
 
 
 
 
9b562d8
 
 
 
 
7a0674a
9b562d8
b75c854
9b562d8
7a0674a
9b562d8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import html

import torch
from transformers import BertTokenizerFast, BertForTokenClassification
import gradio as gr

# Load the word-piece tokenizer and the fine-tuned token-classification
# model, switch the model to inference mode with eval(), and place it on the
# GPU when CUDA is available (CPU otherwise). Both downloads happen at import
# time.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = (
    BertForTokenClassification
    .from_pretrained('maximuspowers/bias-detection-ner')
    .eval()
    .to('cuda' if torch.cuda.is_available() else 'cpu')
)

# Maps a column index of the model's output logits to its tag name.
# The B-/I-/O prefixes presumably follow the BIO convention (B = span start,
# I = span continuation, O = outside any span); the order is assumed to match
# the model's training label order -- confirm against the model config.
id2label = dict(enumerate([
    'O',
    'B-STEREO', 'I-STEREO',
    'B-GEN', 'I-GEN',
    'B-UNFAIR', 'I-UNFAIR',
]))

def predict_ner_tags(sentence, threshold=0.5):
    """Tag each word-piece of *sentence* with zero or more labels.

    The model's logits are scored as independent per-label probabilities
    (sigmoid, i.e. multi-label rather than softmax), so a single token may
    receive several labels at once.

    Args:
        sentence: Input text. ``None`` is treated as an empty string.
        threshold: A label is assigned when its sigmoid probability is
            strictly greater than this value. Defaults to 0.5.

    Returns:
        A list of ``(token, labels)`` tuples, one per non-special word-piece
        token (tokens listed in ``tokenizer.all_special_tokens`` are
        skipped). ``labels`` is a list of names from ``id2label``, or
        ``['O']`` when no label exceeds *threshold*. Input is truncated to
        128 word-pieces including special tokens.
    """
    if sentence is None:
        sentence = ""

    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        # Boolean matrix (batch, seq_len, num_labels): True where a label fires.
        predicted = torch.sigmoid(logits) > threshold

    # Build the lookup set once instead of scanning the list for every token.
    special_tokens = set(tokenizer.all_special_tokens)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    result = []
    for token, row in zip(tokens, predicted[0]):
        if token in special_tokens:
            continue
        label_indices = row.nonzero(as_tuple=False).squeeze(-1)
        labels = [id2label[idx.item()] for idx in label_indices] or ['O']
        result.append((token, labels))

    return result

def format_output(result):
    """Render ``(token, labels)`` pairs as an HTML fragment.

    Each token becomes a ``<span>`` whose inline style reflects its labels:
    STEREO -> blue bottom border, GEN -> green background with white text,
    UNFAIR -> dashed red border. Styles are combined when a token carries
    several labels, and ``B-``/``I-`` variants share the same style.

    Args:
        result: Iterable of ``(token, labels)`` pairs, such as the output of
            ``predict_ner_tags``.

    Returns:
        An HTML string wrapped in a single ``<div>``. Token text is
        HTML-escaped.
    """
    # (tags that trigger the style, CSS), in the order styles are concatenated.
    label_styles = (
        (("B-STEREO", "I-STEREO"), "border-bottom: 2px solid blue;"),
        (("B-GEN", "I-GEN"), "background-color: green; color: white;"),
        (("B-UNFAIR", "I-UNFAIR"), "border: 2px dashed red;"),
    )

    spans = []
    for token, labels in result:
        style_string = " ".join(
            css for tags, css in label_styles if any(tag in labels for tag in tags)
        )
        # Tokens originate from user-typed text; escape them so characters such
        # as '<' or '&' are displayed literally instead of being parsed as markup.
        spans.append(
            f"<span style='{style_string} padding: 3px; margin: 2px;'>"
            f"{html.escape(token)}</span> "
        )

    return "<div style='font-family: Arial;'>" + "".join(spans) + "</div>"

def _predict_ner_html(sentence):
    """Predict tags for *sentence* and return them rendered as HTML.

    ``predict_ner_tags`` returns a list of ``(token, labels)`` tuples, while
    the interface below declares an ``"html"`` output, so the tuples are
    passed through ``format_output`` to produce the HTML string.
    """
    return format_output(predict_ner_tags(sentence))


iface = gr.Interface(
    # The HTML output needs a rendered string, not the raw tuple list.
    fn=_predict_ner_html,
    inputs="text",
    outputs="html",  # Directly use "html" here
    title="Named Entity Recognition with BERT",
    description="Enter a sentence to predict NER tags using a BERT model trained for multi-label classification. Different styles represent different entity types.",
    examples=["Tall men are so clumsy."],
    allow_flagging="never"
)

if __name__ == "__main__":
    # Start the Gradio app only when run as a script, not on import.
    iface.launch()