# app.py -- copied from a Hugging Face Space file view.
# Last commit: "Update app.py" by maximuspowers (34ab835, verified); file size 6.15 kB.
import html
import json

import gradio as gr
import torch
from transformers import BertTokenizerFast, BertForTokenClassification
# Load the tokenizer and the token-classification model once, at import time.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
# Put the model in evaluation mode, then place it on the GPU when one is
# available and on the CPU otherwise.
model.eval()
if torch.cuda.is_available():
    model.to('cuda')
else:
    model.to('cpu')
# Class index -> BIO tag. The list position is the class index.
id2label = dict(enumerate([
    'O',
    'B-STEREO',
    'I-STEREO',
    'B-GEN',
    'I-GEN',
    'B-UNFAIR',
    'I-UNFAIR',
]))
# Reverse lookup: BIO tag -> class index.
label2id = {tag: index for index, tag in id2label.items()}

# Background color used to highlight each entity type in the HTML output.
label_colors = dict(
    STEREO="rgba(255, 0, 0, 0.2)",  # Light Red
    GEN="rgba(0, 0, 255, 0.2)",  # Light Blue
    UNFAIR="rgba(0, 255, 0, 0.2)",  # Light Green
)
# Post-process entity tags
def post_process_entities(result):
    """Normalize the BIO tags of a token sequence in place.

    For every token the labels are de-duplicated (keeping first-seen order)
    and each entity type present is reduced to exactly one tag:

    * ``I-<TYPE>`` when the previous token also carried ``<TYPE>``;
    * ``B-<TYPE>`` otherwise (the entity starts here).

    The predicted prefix is therefore only used to detect that the type is
    present; the emitted prefix follows from the previous token. Entity
    types are tracked independently, so a token tagged with several types
    is handled consistently for each of them. Labels that are not of the
    form ``B-...``/``I-...`` (for example ``O``) are kept unchanged.

    Args:
        result: List of dicts, each with a ``"labels"`` list of tag strings.
            The dicts are updated in place.

    Returns:
        The same ``result`` list, with every ``"labels"`` entry normalized.
    """
    prev_types = set()
    for token_data in result:
        # dict.fromkeys removes duplicates while keeping a stable order;
        # a plain set() would make the output order vary between runs.
        labels = list(dict.fromkeys(token_data["labels"]))

        current_types = set()
        fixed_labels = []
        for label in labels:
            entity_type = label[2:] if label.startswith(("B-", "I-")) else ""
            if not entity_type:
                # 'O' and anything else that is not an entity tag.
                fixed_labels.append(label)
                continue
            if entity_type in current_types:
                # Second tag for the same type on this token (B-/I- conflict).
                continue
            current_types.add(entity_type)
            prefix = "I" if entity_type in prev_types else "B"
            fixed_labels.append(f"{prefix}-{entity_type}")

        token_data["labels"] = fixed_labels
        prev_types = current_types
    return result
# Generate JSON results
def generate_json(sentence):
    """Tag every token of ``sentence`` and return the result as a JSON string.

    The sentence is tokenized (truncated to at most 128 tokens) and run
    through the model. A label is assigned to a token when the sigmoid of
    its logit exceeds 0.5, so one token may carry several labels; a token
    with no label above the threshold gets ``['O']``. Special tokens are
    skipped, ``##`` is removed from the token text, and the tags are
    then normalized by ``post_process_entities``.

    Args:
        sentence: Text to analyze.

    Returns:
        A JSON string (4-space indent) encoding a list of
        ``{"token": str, "labels": [str, ...]}`` objects.
    """
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        # Threshold the first (only) sequence and convert to nested Python
        # lists once, instead of running tensor ops for every token.
        active_flags = (torch.sigmoid(logits)[0] > 0.5).tolist()

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    # Looked up once; membership tests against a set are O(1).
    special_tokens = set(tokenizer.all_special_tokens)

    result = []
    for token, flags in zip(tokens, active_flags):
        if token in special_tokens:
            continue
        labels = [id2label[index] for index, is_on in enumerate(flags) if is_on] or ['O']
        result.append({"token": token.replace("##", ""), "labels": labels})

    result = post_process_entities(result)
    return json.dumps(result, indent=4)
# Predict function
def _bias_label_cell(labels, entity_type):
    """Return the HTML for one matrix cell: a colored badge, or a blank."""
    # Substring match picks up both the B- and I- tag of the entity type;
    # [2:] drops the two-character BIO prefix for display.
    matching = [label[2:] for label in labels if entity_type in label]
    if not matching:
        return "&nbsp;"
    badge_text = html.escape(', '.join(matching))
    return (
        f"<span style='background:{label_colors[entity_type]}; "
        f"border-radius:6px; padding:2px 5px;'>{badge_text}</span>"
    )


def _bias_matrix_row(title, cells):
    """Return one ``<tr>``: a bold row title followed by the given cells."""
    body = ''.join(f"<td>{cell}</td>" for cell in cells)
    return f"<tr>\n<td><strong>{title}</strong></td>\n{body}\n</tr>"


def predict_ner_tags_with_json(sentence):
    """Render the bias tags of ``sentence`` as an HTML matrix plus raw JSON.

    The sentence is tagged with ``generate_json``. The output is a table
    with one column per token and four rows (token text, GEN, UNFAIR,
    STEREO), followed by the JSON dump inside a ``<pre>`` element. Token
    text and the JSON dump originate from user input, so both are
    HTML-escaped before being embedded in the markup.

    Args:
        sentence: Text to analyze.

    Returns:
        An HTML string.
    """
    json_result = generate_json(sentence)
    result = json.loads(json_result)

    word_row = [
        f"<span style='font-weight:bold;'>{html.escape(token_data['token'])}</span>"
        for token_data in result
    ]
    gen_row = [_bias_label_cell(token_data["labels"], "GEN") for token_data in result]
    unfair_row = [_bias_label_cell(token_data["labels"], "UNFAIR") for token_data in result]
    stereo_row = [_bias_label_cell(token_data["labels"], "STEREO") for token_data in result]

    table_rows = "\n".join([
        _bias_matrix_row("Text Sequence", word_row),
        _bias_matrix_row("Generalizations", gen_row),
        _bias_matrix_row("Unfairness", unfair_row),
        _bias_matrix_row("Stereotypes", stereo_row),
    ])
    matrix_html = (
        "\n<table style='border-collapse:collapse; width:100%; "
        "font-family:monospace; text-align:left;'>\n"
        f"{table_rows}\n"
        "</table>\n"
    )
    # quote=False: inside <pre> only <, > and & need escaping.
    return f"{matrix_html}<br><pre>{html.escape(json_result, quote=False)}</pre>"
# Gradio UI: an introduction, the input textbox, and the rendered HTML output.
with gr.Blocks() as iface:
    with gr.Row():
        gr.Markdown(
            """
# Social Bias Named Entity Recognition (with BERT) 🕵
Enter a sentence to predict biased parts of speech tags. This model uses multi-label `BertForTokenClassification` to label the entities:
- **Generalizations (GEN)**
- **Unfairness (UNFAIR)**
- **Stereotypes (STEREO)**
Labels follow the BIO format. Try it out!
- **[Blog Post](https://huggingface.co/blog/maximuspowers/bias-entity-recognition)**
- **[Model Page](https://huggingface.co/maximuspowers/bias-detection-ner)**
"""
        )

    with gr.Row():
        input_box = gr.Textbox(label="Input Sentence")

    with gr.Row():
        output_box = gr.HTML(label="Entity Matrix and JSON Output")

    # Re-run the prediction whenever the textbox content changes.
    input_box.change(
        predict_ner_tags_with_json,
        inputs=[input_box],
        outputs=[output_box],
    )

iface.launch(share=True)