|
import html

import torch
from transformers import BertTokenizerFast, BertForTokenClassification
import gradio as gr
|
|
|
|
|
# Tokenizer and token-classification model, loaded once at import time and
# shared by every call to predict_ner_tags. The model is placed on the GPU
# when CUDA is available, otherwise on the CPU, and switched to eval mode.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
_device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(_device)
model.eval()
|
|
|
|
|
# Output index of the classifier head -> BIO-style tag string.
# List order defines the index, so it must match the model's label layout.
id2label = dict(enumerate([
    'O',
    'B-STEREO', 'I-STEREO',
    'B-GEN', 'I-GEN',
    'B-UNFAIR', 'I-UNFAIR',
]))
|
|
|
def predict_ner_tags(sentence):
    """Tag every word piece of *sentence* with all labels scoring above 0.5.

    The logits are decoded as multi-label: a sigmoid is applied to each label
    independently, so one token may carry several labels at once.

    Args:
        sentence: Input text; truncated to at most 128 tokens.

    Returns:
        A list of ``(token, labels)`` tuples, one per non-special token, in
        order. ``labels`` is a list of tag strings from ``id2label``, or
        ``['O']`` when no label crosses the threshold.
    """
    encoded = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
    ids = encoded['input_ids'].to(model.device)
    mask = encoded['attention_mask'].to(model.device)

    with torch.no_grad():
        scores = model(input_ids=ids, attention_mask=mask).logits
        # Independent per-label sigmoid (multi-label), not a softmax.
        hits = (torch.sigmoid(scores) > 0.5).int()

    special = tokenizer.all_special_tokens
    tagged = []
    for position, piece in enumerate(tokenizer.convert_ids_to_tokens(ids[0])):
        # Skip [CLS], [SEP], [PAD], etc.
        if piece in special:
            continue
        active = (hits[0][position] == 1).nonzero(as_tuple=False).squeeze(-1)
        if active.numel() > 0:
            names = [id2label[k.item()] for k in active]
        else:
            names = ['O']
        tagged.append((piece, names))
    return tagged
|
|
|
def format_output(result):
    """Render tagged tokens as an HTML fragment.

    Args:
        result: Iterable of ``(token, labels)`` pairs, e.g. the output of
            ``predict_ner_tags``; ``labels`` is a collection of tag strings.

    Returns:
        A ``<div>`` holding one ``<span>`` per token, each followed by a space.
        The styles are:

        * ``*-STEREO`` tags: blue underline.
        * ``*-GEN`` tags: green background.
        * ``*-UNFAIR`` tags: dashed red border.

        Styles combine when a token carries several tags. Token text is
        HTML-escaped.
    """
    # (tags that trigger the style, CSS to apply), in output order.
    label_styles = (
        (("B-STEREO", "I-STEREO"), "border-bottom: 2px solid blue;"),
        (("B-GEN", "I-GEN"), "background-color: green; color: white;"),
        (("B-UNFAIR", "I-UNFAIR"), "border: 2px dashed red;"),
    )
    spans = []
    for token, labels in result:
        style_string = " ".join(
            css
            for triggers, css in label_styles
            if any(tag in labels for tag in triggers)
        )
        # Tokens derive from user-supplied text and are rendered as HTML,
        # so escape them to prevent markup/script injection.
        spans.append(
            f"<span style='{style_string} padding: 3px; margin: 2px;'>"
            f"{html.escape(token)}</span> "
        )
    # join instead of repeated += to avoid quadratic string building.
    return "<div style='font-family: Arial;'>" + "".join(spans) + "</div>"
|
|
|
def highlight_bias(sentence):
    """Tag *sentence* and return its HTML rendering for the UI."""
    return format_output(predict_ner_tags(sentence))


# The output component renders HTML, so the raw (token, labels) pairs from
# predict_ner_tags are passed through format_output before display.
# gr.HTML is the component class; the gr.outputs namespace was removed in Gradio 4.
iface = gr.Interface(
    fn=highlight_bias,
    inputs="text",
    outputs=gr.HTML(label="Output"),
    title="Named Entity Recognition with BERT",
    description="Enter a sentence to predict NER tags using a BERT model trained for multi-label classification. Different styles represent different entity types.",
    examples=["Tall men are so clumsy."],
    allow_flagging="never",
    theme="default",
)


if __name__ == "__main__":
    iface.launch()
|
|