maximuspowers commited on
Commit
939c704
·
verified ·
1 Parent(s): 5dfca2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -14
app.py CHANGED
@@ -22,15 +22,20 @@ id2label = {
22
 
23
  # Color map for entities
24
  label_colors = {
25
- "B-STEREO": "#FFCDD2",
26
- "I-STEREO": "#E57373",
27
- "B-GEN": "#C8E6C9",
28
- "I-GEN": "#81C784",
29
- "B-UNFAIR": "#BBDEFB",
30
- "I-UNFAIR": "#64B5F6",
31
- "O": "#FFFFFF" # Default for no label
32
  }
33
 
 
 
 
 
 
 
 
 
 
34
  # Predict function
35
  def predict_ner_tags(sentence):
36
  inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
@@ -43,19 +48,34 @@ def predict_ner_tags(sentence):
43
  probabilities = torch.sigmoid(logits)
44
  predicted_labels = (probabilities > 0.5).int() # Threshold
45
 
46
- result = []
47
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
48
  highlighted_sentence = ""
 
 
49
  for i, token in enumerate(tokens):
50
  if token not in tokenizer.all_special_tokens:
 
51
  label_indices = (predicted_labels[0][i] == 1).nonzero(as_tuple=False).squeeze(-1)
52
- labels = [id2label[idx.item()] for idx in label_indices] if label_indices.numel() > 0 else ['O']
53
- # Get the most prominent label for coloring (arbitrary choice for multiple labels)
54
- primary_label = labels[0] if labels else "O"
55
- color = label_colors.get(primary_label, "#FFFFFF")
56
- highlighted_sentence += f"<span style='background-color:{color}'>{token}</span> "
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- return highlighted_sentence.strip()
59
 
60
  # Gradio Interface
61
  iface = gr.Interface(
 
22
 
23
  # Color map for entities
24
  label_colors = {
25
+ "STEREO": "rgba(255, 0, 0, 0.3)", # Red
26
+ "GEN": "rgba(0, 0, 255, 0.3)", # Blue
27
+ "UNFAIR": "rgba(0, 255, 0, 0.3)" # Green
 
 
 
 
28
  }
29
 
30
+ # Helper to wrap a token in a span with color
31
+ def wrap_token_with_color(token, labels):
32
+ # Build nested highlights
33
+ style = "position: relative;"
34
+ for label in labels:
35
+ if label != "O":
36
+ style += f"background: {label_colors[label]};"
37
+ return f"<span style='{style}'>{token}</span>"
38
+
39
  # Predict function
40
  def predict_ner_tags(sentence):
41
  inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
 
48
  probabilities = torch.sigmoid(logits)
49
  predicted_labels = (probabilities > 0.5).int() # Threshold
50
 
 
51
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
52
  highlighted_sentence = ""
53
+ prev_labels = []
54
+
55
  for i, token in enumerate(tokens):
56
  if token not in tokenizer.all_special_tokens:
57
+ # Extract the labels for this token
58
  label_indices = (predicted_labels[0][i] == 1).nonzero(as_tuple=False).squeeze(-1)
59
+ labels = [id2label[idx.item()][2:] for idx in label_indices] if label_indices.numel() > 0 else ['O']
60
+
61
+ # Check if labels are the same as the previous token (for seamless highlighting)
62
+ if labels != prev_labels:
63
+ if prev_labels: # Close the previous span if needed
64
+ highlighted_sentence += "</span>"
65
+
66
+ # Start a new span
67
+ if labels != ["O"]:
68
+ highlighted_sentence += f"<span style='background: linear-gradient({', '.join([label_colors[label] for label in labels])})'>"
69
+
70
+ # Add the token to the span
71
+ highlighted_sentence += token.replace("##", "")
72
+ prev_labels = labels
73
+
74
+ # Close any open spans
75
+ if prev_labels and prev_labels != ["O"]:
76
+ highlighted_sentence += "</span>"
77
 
78
+ return highlighted_sentence
79
 
80
  # Gradio Interface
81
  iface = gr.Interface(