Update app.py
app.py CHANGED
@@ -22,15 +22,20 @@ id2label = {
 
 # Color map for entities
 label_colors = {
-    "B-STEREO": "…",
-    "I-STEREO": "…",
-    "B-GEN": "…",
-    "I-GEN": "#81C784",
-    "B-UNFAIR": "#BBDEFB",
-    "I-UNFAIR": "#64B5F6",
-    "O": "#FFFFFF" # Default for no label
+    "STEREO": "rgba(255, 0, 0, 0.3)", # Red
+    "GEN": "rgba(0, 0, 255, 0.3)", # Blue
+    "UNFAIR": "rgba(0, 255, 0, 0.3)" # Green
 }
 
+# Helper to wrap a token in a span with color
+def wrap_token_with_color(token, labels):
+    # Build nested highlights
+    style = "position: relative;"
+    for label in labels:
+        if label != "O":
+            style += f"background: {label_colors[label]};"
+    return f"<span style='{style}'>{token}</span>"
+
 # Predict function
 def predict_ner_tags(sentence):
     inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
@@ -43,19 +48,34 @@ def predict_ner_tags(sentence):
     probabilities = torch.sigmoid(logits)
     predicted_labels = (probabilities > 0.5).int() # Threshold
 
-    result = []
     tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
     highlighted_sentence = ""
+    prev_labels = []
+
     for i, token in enumerate(tokens):
         if token not in tokenizer.all_special_tokens:
+            # Extract the labels for this token
             label_indices = (predicted_labels[0][i] == 1).nonzero(as_tuple=False).squeeze(-1)
-            labels = [id2label[idx.item()] for idx in label_indices] if label_indices.numel() > 0 else ['O']
-            …
-            …
-            …
-            …
+            labels = [id2label[idx.item()][2:] for idx in label_indices] if label_indices.numel() > 0 else ['O']
+
+            # Check if labels are the same as the previous token (for seamless highlighting)
+            if labels != prev_labels:
+                if prev_labels: # Close the previous span if needed
+                    highlighted_sentence += "</span>"
+
+                # Start a new span
+                if labels != ["O"]:
+                    highlighted_sentence += f"<span style='background: linear-gradient({', '.join([label_colors[label] for label in labels])})'>"
+
+            # Add the token to the span
+            highlighted_sentence += token.replace("##", "")
+            prev_labels = labels
+
+    # Close any open spans
+    if prev_labels and prev_labels != ["O"]:
+        highlighted_sentence += "</span>"
 
-    return highlighted_sentence
+    return highlighted_sentence
 
 # Gradio Interface
 iface = gr.Interface(