# Hugging Face Space page header captured with this source (status: Sleeping).
import torch
from transformers import BertTokenizerFast, BertForTokenClassification
import gradio as gr

# Load the tokenizer and the fine-tuned token-classification model once, at import time.
# NOTE(review): the tokenizer is taken from the base 'bert-base-uncased' checkpoint rather
# than from the model repo; presumably the fine-tune kept that vocabulary — confirm.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
# eval() returns the module itself, so the model is put in inference mode as it is loaded.
model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner').eval()
# Place the weights on the GPU when CUDA is available, otherwise keep them on the CPU.
model.to('cuda' if torch.cuda.is_available() else 'cpu')
# Map each logit column index of the classifier to its tag name: 'O' plus
# B-/I- prefixed tags for the STEREO, GEN and UNFAIR categories.
id2label = dict(enumerate([
    'O',
    'B-STEREO', 'I-STEREO',
    'B-GEN', 'I-GEN',
    'B-UNFAIR', 'I-UNFAIR',
]))
def predict_ner_tags(sentence, *, threshold=0.5):
    """Tag every word-piece of *sentence* with the labels the model predicts.

    The model is used as a multi-label classifier: each label whose sigmoid
    score exceeds *threshold* is reported for a token. A token with no such
    label is tagged ``['O']``.

    Args:
        sentence: Raw input text. It is tokenized with ``tokenizer`` and
            truncated to 128 word-pieces (special tokens included), so longer
            input is silently cut off.
        threshold: Probability above which a label counts as predicted.
            Defaults to 0.5, the previously hard-coded cut-off.

    Returns:
        A list of ``(token, labels)`` pairs in input order, one per word-piece
        as produced by ``tokenizer.convert_ids_to_tokens``. The tokens the
        tokenizer itself inserts ([CLS], [SEP], padding) are excluded.
        ``labels`` is a list of tag names taken from ``id2label``.
    """
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    # Independent per-label decision (sigmoid, not softmax): a token may carry
    # several tags at once.
    predicted = (torch.sigmoid(logits[0]) > threshold).tolist()
    # Decode from the CPU copy of the ids; there is no need to read back from the device.
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # word_ids() is None for the special tokens the tokenizer adds. Filtering on
    # it instead of on special-token *strings* keeps words that were mapped to
    # [UNK]; the previous string check silently dropped them.
    word_ids = inputs.word_ids(0)
    result = []
    for token, word_id, flags in zip(tokens, word_ids, predicted):
        if word_id is None:
            continue
        labels = [id2label[j] for j, on in enumerate(flags) if on]
        result.append((token, labels or ["O"]))
    return result
def format_output(result):
    """Render ``(token, labels)`` pairs as text, one ``token: label, label`` line each.

    Args:
        result: Iterable of ``(token, labels)`` pairs, such as the list returned
            by ``predict_ner_tags``. ``labels`` is a sequence of strings.

    Returns:
        The lines concatenated, each terminated by a newline. Returns an empty
        string when *result* is empty.
    """
    # str.join builds the text in one pass instead of repeated `+=`, which can
    # be quadratic in the number of tokens.
    return "".join(f"{token}: {', '.join(labels)}\n" for token, labels in result)
def _predict_and_format(sentence):
    """Run the tagger on *sentence* and return its predictions as display text."""
    return format_output(predict_ner_tags(sentence))


# gr.Interface has no ``postprocessing_fn`` argument. The previous wiring therefore
# never applied format_output, and the text output received the raw list of
# (token, labels) tuples; formatting now happens inside the wrapped fn.
# ``interpretation="default"`` was dropped too: Gradio 3.x default interpretation
# supports only Label/Number outputs, and Gradio 4 removed the option entirely.
iface = gr.Interface(
    fn=_predict_and_format,
    inputs="text",
    outputs="text",
    title="Named Entity Recognition with BERT",
    description="Enter a sentence to predict NER tags using BERT model trained for multi-label classification.",
    examples=["Tall men are so clumsy."],
    allow_flagging="never",
)
def _main():
    """Start the Gradio web server for ``iface``."""
    iface.launch()


# Launch only when run as a script, not when this module is imported.
if __name__ == "__main__":
    _main()