import html

import torch
from transformers import BertTokenizerFast, BertForTokenClassification
import gradio as gr

# Load tokenizer and model once at import time so every request reuses them.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
model.eval()  # Set the model to evaluation mode
model.to('cuda' if torch.cuda.is_available() else 'cpu')  # Move model to appropriate device

# Define label mappings (output-logit index -> tag).
# NOTE(review): assumed to match the label order the model was trained with;
# confirm against the model card / model.config.id2label.
id2label = {
    0: 'O',
    1: 'B-STEREO',
    2: 'I-STEREO',
    3: 'B-GEN',
    4: 'I-GEN',
    5: 'B-UNFAIR',
    6: 'I-UNFAIR',
}

# Inline CSS per entity type, applied in this order. A token may carry several
# entity types at once because the model is decoded as a multi-label classifier.
ENTITY_STYLES = {
    "STEREO": "border-bottom: 2px solid blue;",
    "GEN": "background-color: green; color: white;",
    "UNFAIR": "border: 2px dashed red;",
}


def predict_ner_tags(sentence, threshold=0.5):
    """Tag each word-piece of *sentence* with zero or more labels from ``id2label``.

    Every label is decided independently (sigmoid + threshold), so one token can
    receive several labels. A token with no label above the threshold gets ``['O']``.

    Args:
        sentence: Input text. It is truncated to 128 word-pieces (including the
            special tokens added by the tokenizer). Empty or ``None`` input
            yields an empty result.
        threshold: Probability above which a label is assigned (default 0.5).

    Returns:
        A list of ``(token, labels)`` tuples, one per non-special word-piece,
        where ``labels`` is a non-empty list of label strings.
    """
    if not sentence:
        return []

    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Multi-label decoding: one boolean row of label decisions per token
    # (first and only sequence in the batch).
    predicted = torch.sigmoid(logits)[0] > threshold

    special_tokens = set(tokenizer.all_special_tokens)  # built once, not per token
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    result = []
    for token, label_mask in zip(tokens, predicted):
        if token in special_tokens:
            continue
        label_ids = label_mask.nonzero().flatten().tolist()
        labels = [id2label[idx] for idx in label_ids] or ['O']
        result.append((token, labels))
    return result


def format_output(result):
    """Render the output of ``predict_ner_tags`` as an HTML fragment.

    Each token becomes a ``<span>`` styled per ``ENTITY_STYLES`` for every entity
    type whose B- or I- tag it carries. Tokens are HTML-escaped because they are
    derived from user input and may contain characters such as ``<`` or ``&``.

    Args:
        result: Iterable of ``(token, labels)`` tuples.

    Returns:
        A ``<div>`` containing one space-separated ``<span>`` per token.
    """
    spans = []
    for token, labels in result:
        styles = [
            css
            for entity, css in ENTITY_STYLES.items()
            if f"B-{entity}" in labels or f"I-{entity}" in labels
        ]
        style_string = " ".join(styles)
        spans.append(f"<span style='{style_string}'>{html.escape(token)}</span>")
    return "<div>" + " ".join(spans) + "</div>"


def highlight_bias(sentence):
    """Gradio callback: tag *sentence* and return the result as styled HTML."""
    return format_output(predict_ner_tags(sentence))


iface = gr.Interface(
    # The HTML output component needs an HTML string, so the callback must run
    # format_output on the raw predictions.
    fn=highlight_bias,
    inputs="text",
    outputs=gr.HTML(label="Output"),  # gr.outputs.* was removed in Gradio 4
    title="Named Entity Recognition with BERT",
    description="Enter a sentence to predict NER tags using a BERT model trained for multi-label classification. Different styles represent different entity types.",
    examples=["Tall men are so clumsy."],
    # NOTE(review): Gradio 5 renames this to flagging_mode; adjust if upgrading.
    allow_flagging="never",
    theme="default",
)

if __name__ == "__main__":
    iface.launch()