maximuspowers's picture
Update app.py
7a0674a verified
raw
history blame
2.61 kB
import html

import torch
from transformers import BertTokenizerFast, BertForTokenClassification
import gradio as gr
# Load the tokenizer and the fine-tuned token classifier once, at import time,
# so every request handled by the app reuses the same objects.
# NOTE(review): the tokenizer comes from 'bert-base-uncased' rather than from the
# model repo; presumably the fine-tuned model kept that vocabulary — confirm.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained(
    'maximuspowers/bias-detection-ner'
).eval()  # inference mode; Module.eval() returns the module itself
# Place the weights on the GPU when CUDA is available, otherwise keep them on CPU.
model.to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
# Map each classifier output index to its tag name. The names follow the usual
# B-/I- prefix scheme for three entity types (STEREO, GEN, UNFAIR), with 'O' for
# tokens outside any entity.
# NOTE(review): the order presumably matches the model's classification head —
# confirm against the model config.
id2label = dict(enumerate((
    'O',
    'B-STEREO', 'I-STEREO',
    'B-GEN', 'I-GEN',
    'B-UNFAIR', 'I-UNFAIR',
)))
def predict_ner_tags(sentence):
    """Predict bias-entity labels for every word-piece token of ``sentence``.

    The classifier is used in a multi-label fashion: each logit goes through an
    independent sigmoid and every label whose probability exceeds 0.5 is kept,
    so a single token may carry several labels.

    Args:
        sentence: Input text. It is truncated to 128 word-piece tokens
            (including the special tokens the tokenizer adds). ``None`` is
            treated as "nothing to tag".

    Returns:
        A list of ``(token, labels)`` tuples, one per non-special word-piece
        token, where ``labels`` is a list of tag names from ``id2label``
        (``['O']`` when no label passes the threshold).
    """
    if sentence is None:
        # Nothing to tag; avoids a tokenizer error on a missing input.
        return []
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    # Independent per-label decisions (multi-label), not a softmax argmax.
    # Row i holds the boolean label decisions for token i of the single sentence.
    predicted = torch.sigmoid(logits)[0] > 0.5
    # all_special_tokens is a property that builds a fresh list on each access;
    # compute it once instead of once per token.
    special_tokens = set(tokenizer.all_special_tokens)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    result = []
    for token, token_mask in zip(tokens, predicted):
        if token in special_tokens:
            continue
        label_ids = token_mask.nonzero(as_tuple=True)[0].tolist()
        labels = [id2label[idx] for idx in label_ids]
        result.append((token, labels or ['O']))
    return result
def format_output(result):
    """Render ``(token, labels)`` pairs as a styled HTML snippet.

    Each token is wrapped in its own ``<span>`` whose inline style reflects
    the entity types found in its labels (either the B- or I- tag counts):

    * STEREO -> blue bottom border
    * GEN    -> green background with white text
    * UNFAIR -> dashed red border

    A token with several entity types receives all matching styles.

    Args:
        result: Iterable of ``(token, labels)`` pairs, e.g. the output of
            ``predict_ner_tags``.

    Returns:
        A ``<div>`` string containing one ``<span>`` per token.
    """
    # (entity type, CSS applied when its B- or I- tag is present), in the
    # order the styles are concatenated.
    entity_styles = (
        ("STEREO", "border-bottom: 2px solid blue;"),
        ("GEN", "background-color: green; color: white;"),
        ("UNFAIR", "border: 2px dashed red;"),
    )
    parts = ["<div style='font-family: Arial;'>"]
    for token, labels in result:
        style_string = " ".join(
            css
            for entity, css in entity_styles
            if f"B-{entity}" in labels or f"I-{entity}" in labels
        )
        # Tokens originate from user-supplied text: escape them so characters
        # such as '<' or '&' are displayed literally, not parsed as markup.
        parts.append(
            f"<span style='{style_string} padding: 3px; margin: 2px;'>"
            f"{html.escape(token)}</span> "
        )
    parts.append("</div>")
    return "".join(parts)
def highlight_bias(sentence):
    """Tag ``sentence`` and return the styled HTML displayed by the interface.

    The interface's output component is ``"html"``, which needs a string;
    ``predict_ner_tags`` returns a list of ``(token, labels)`` pairs, so the
    pairs are rendered through ``format_output`` first.
    """
    return format_output(predict_ner_tags(sentence))


# Gradio UI: a text box in, the rendered token highlights out.
iface = gr.Interface(
    fn=highlight_bias,  # returns an HTML string, matching outputs="html"
    inputs="text",
    outputs="html",
    title="Named Entity Recognition with BERT",
    description="Enter a sentence to predict NER tags using a BERT model trained for multi-label classification. Different styles represent different entity types.",
    examples=["Tall men are so clumsy."],
    allow_flagging="never"
)
if __name__ == "__main__":
    iface.launch()