Spaces:
Running
on
Zero
Running
on
Zero
threshold reduced, more aggressive tagger
Browse files- agent.py +1 -1
- instruction_classifier.py +177 -87
- utils.py +20 -3
agent.py
CHANGED
@@ -549,7 +549,7 @@ Body: {email.body_value}"""
|
|
549 |
|
550 |
|
551 |
# Import the instruction classifier sanitizer
|
552 |
-
from instruction_classifier import
|
553 |
|
554 |
|
555 |
def extract_tool_calls(text):
|
|
|
549 |
|
550 |
|
551 |
# Import the instruction classifier sanitizer
|
552 |
+
from instruction_classifier import sanitize_tool_output_with_annotations
|
553 |
|
554 |
|
555 |
def extract_tool_calls(text):
|
instruction_classifier.py
CHANGED
@@ -51,6 +51,7 @@ class InstructionClassifierSanitizer:
|
|
51 |
model_filename: str = "best_instruction_classifier.pth",
|
52 |
model_name: str = "xlm-roberta-base",
|
53 |
threshold: float = 0.01,
|
|
|
54 |
max_length: int = 512,
|
55 |
overlap: int = 256,
|
56 |
use_local_model: bool = False # Set to False to use HF Hub
|
@@ -63,13 +64,15 @@ class InstructionClassifierSanitizer:
|
|
63 |
model_repo_id: Hugging Face model repository ID (if use_local_model=False)
|
64 |
model_filename: Filename of the model in the HF repository
|
65 |
model_name: Base transformer model name
|
66 |
-
threshold:
|
|
|
67 |
max_length: Maximum sequence length for sliding windows
|
68 |
overlap: Overlap between sliding windows
|
69 |
use_local_model: Whether to use local model file or download from HF Hub
|
70 |
"""
|
71 |
self.model_name = model_name
|
72 |
self.threshold = threshold
|
|
|
73 |
self.max_length = max_length
|
74 |
self.overlap = overlap
|
75 |
self.use_local_model = use_local_model
|
@@ -181,54 +184,6 @@ class InstructionClassifierSanitizer:
|
|
181 |
self.model.to(self.device) # Keep on CPU during initialization
|
182 |
self.model.eval()
|
183 |
|
184 |
-
@spaces.GPU
|
185 |
-
def sanitize_tool_output(self, tool_output: str) -> str:
|
186 |
-
"""
|
187 |
-
Main sanitization function that processes tool output and removes instruction content
|
188 |
-
|
189 |
-
Args:
|
190 |
-
tool_output: The raw tool output string
|
191 |
-
|
192 |
-
Returns:
|
193 |
-
Sanitized tool output with instruction content removed
|
194 |
-
"""
|
195 |
-
if not tool_output or not tool_output.strip():
|
196 |
-
return tool_output
|
197 |
-
|
198 |
-
# Move model to target device (GPU) within @spaces.GPU decorated method
|
199 |
-
if self.device != self.target_device:
|
200 |
-
print(f"🚀 Moving model from {self.device} to {self.target_device} within @spaces.GPU context")
|
201 |
-
self.model.to(self.target_device)
|
202 |
-
self.device = self.target_device
|
203 |
-
|
204 |
-
try:
|
205 |
-
# Step 1: Detect if the tool output contains instructions
|
206 |
-
is_injection, confidence_score, tagged_text = self._detect_injection(tool_output)
|
207 |
-
|
208 |
-
print(f"🔍 Instruction detection: injection={is_injection}, confidence={confidence_score:.3f}")
|
209 |
-
|
210 |
-
if not is_injection:
|
211 |
-
print("✅ No injection detected - returning original output")
|
212 |
-
return tool_output
|
213 |
-
|
214 |
-
print(f"🚨 Injection detected! Sanitizing output...")
|
215 |
-
print(f" Original: {tool_output}")
|
216 |
-
print(f" Tagged: {tagged_text}")
|
217 |
-
|
218 |
-
# Step 2: Merge close instruction tags
|
219 |
-
merged_tagged_text = self._merge_close_instruction_tags(tagged_text, min_words_between=4)
|
220 |
-
print(f" After merging: {merged_tagged_text}")
|
221 |
-
|
222 |
-
# Step 3: Remove instruction tags and their content
|
223 |
-
sanitized_output = self._remove_instruction_tags(merged_tagged_text)
|
224 |
-
print(f" Sanitized: {sanitized_output}")
|
225 |
-
|
226 |
-
return sanitized_output
|
227 |
-
|
228 |
-
except Exception as e:
|
229 |
-
print(f"❌ Error in instruction classifier sanitization: {e}")
|
230 |
-
# Return original output if sanitization fails
|
231 |
-
return tool_output
|
232 |
|
233 |
@spaces.GPU
|
234 |
def sanitize_with_annotations(self, tool_output: str) -> Tuple[str, List[Dict[str, any]]]:
|
@@ -261,17 +216,23 @@ class InstructionClassifierSanitizer:
|
|
261 |
print("✅ No injection detected - returning original output")
|
262 |
return tool_output, []
|
263 |
|
264 |
-
print(f"🚨 Injection detected!
|
265 |
|
266 |
-
# Step 2: Extract annotation positions from tagged text
|
267 |
annotations = self._extract_annotations_from_tagged_text(tagged_text, tool_output)
|
|
|
268 |
|
269 |
-
# Step 3:
|
270 |
-
|
|
|
271 |
|
272 |
-
# Step 4:
|
273 |
-
|
|
|
274 |
|
|
|
|
|
|
|
275 |
return sanitized_output, annotations
|
276 |
|
277 |
except Exception as e:
|
@@ -367,11 +328,12 @@ class InstructionClassifierSanitizer:
|
|
367 |
from utils import predict_instructions
|
368 |
|
369 |
try:
|
370 |
-
# Use the predict_instructions function directly
|
371 |
-
tokens, predictions = predict_instructions(self.model, self.tokenizer, text, self.device)
|
372 |
return predictions, tokens
|
373 |
except Exception as e:
|
374 |
-
print(f"Error in predict_instructions: {e}")
|
|
|
375 |
# Fallback to simple tokenization if the complex method fails
|
376 |
return self._simple_predict(text)
|
377 |
|
@@ -402,7 +364,9 @@ class InstructionClassifierSanitizer:
|
|
402 |
self.model.eval()
|
403 |
with torch.no_grad():
|
404 |
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
405 |
-
|
|
|
|
|
406 |
|
407 |
# Convert back to word-level predictions
|
408 |
word_ids = encoded.word_ids()
|
@@ -494,6 +458,160 @@ class InstructionClassifierSanitizer:
|
|
494 |
|
495 |
return result
|
496 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
497 |
def _merge_close_instruction_tags(self, text, min_words_between=3):
|
498 |
"""
|
499 |
Merge <instruction>...</instruction> segments that are separated by less than min_words_between words
|
@@ -572,34 +690,6 @@ def get_sanitizer():
|
|
572 |
return None
|
573 |
return _sanitizer_instance
|
574 |
|
575 |
-
def sanitize_tool_output(tool_output, defense_enabled=True):
|
576 |
-
"""
|
577 |
-
Main sanitization function that uses the instruction classifier to detect and remove
|
578 |
-
prompt injection attempts from tool outputs.
|
579 |
-
|
580 |
-
Args:
|
581 |
-
tool_output: The raw tool output string
|
582 |
-
defense_enabled: Whether defense is enabled (passed from agent)
|
583 |
-
|
584 |
-
Returns:
|
585 |
-
Sanitized tool output with instruction content removed
|
586 |
-
"""
|
587 |
-
print(f"🔍 sanitize_tool_output called with: {tool_output[:100]}...")
|
588 |
-
|
589 |
-
# If defense is disabled globally, return original output
|
590 |
-
if not defense_enabled:
|
591 |
-
print("⚠️ Defense disabled - returning original output without processing")
|
592 |
-
return tool_output
|
593 |
-
|
594 |
-
sanitizer = get_sanitizer()
|
595 |
-
if sanitizer is None:
|
596 |
-
print("⚠️ Instruction classifier not available, returning original output")
|
597 |
-
return tool_output
|
598 |
-
|
599 |
-
print("✅ Sanitizer found, processing...")
|
600 |
-
result = sanitizer.sanitize_tool_output(tool_output)
|
601 |
-
print(f"🔒 Sanitization complete, result: {result[:100]}...")
|
602 |
-
return result
|
603 |
|
604 |
def sanitize_tool_output_with_annotations(tool_output, defense_enabled=True):
|
605 |
"""
|
|
|
51 |
model_filename: str = "best_instruction_classifier.pth",
|
52 |
model_name: str = "xlm-roberta-base",
|
53 |
threshold: float = 0.01,
|
54 |
+
token_threshold: float = 0.4,
|
55 |
max_length: int = 512,
|
56 |
overlap: int = 256,
|
57 |
use_local_model: bool = False # Set to False to use HF Hub
|
|
|
64 |
model_repo_id: Hugging Face model repository ID (if use_local_model=False)
|
65 |
model_filename: Filename of the model in the HF repository
|
66 |
model_name: Base transformer model name
|
67 |
+
threshold: Document-level threshold - proportion of tokens that must be INSTRUCTION to classify document as injection
|
68 |
+
token_threshold: Token-level threshold - probability threshold for classifying individual tokens as INSTRUCTION (0.0-1.0, lower = more aggressive)
|
69 |
max_length: Maximum sequence length for sliding windows
|
70 |
overlap: Overlap between sliding windows
|
71 |
use_local_model: Whether to use local model file or download from HF Hub
|
72 |
"""
|
73 |
self.model_name = model_name
|
74 |
self.threshold = threshold
|
75 |
+
self.token_threshold = token_threshold
|
76 |
self.max_length = max_length
|
77 |
self.overlap = overlap
|
78 |
self.use_local_model = use_local_model
|
|
|
184 |
self.model.to(self.device) # Keep on CPU during initialization
|
185 |
self.model.eval()
|
186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
|
188 |
@spaces.GPU
|
189 |
def sanitize_with_annotations(self, tool_output: str) -> Tuple[str, List[Dict[str, any]]]:
|
|
|
216 |
print("✅ No injection detected - returning original output")
|
217 |
return tool_output, []
|
218 |
|
219 |
+
print(f"🚨 Injection detected! Processing with extensions and annotations...")
|
220 |
|
221 |
+
# Step 2: Extract annotation positions from original tagged text
|
222 |
annotations = self._extract_annotations_from_tagged_text(tagged_text, tool_output)
|
223 |
+
print(f"📝 Original tagged text: {tagged_text}")
|
224 |
|
225 |
+
# Step 3: Extend instruction tags by one token on each side
|
226 |
+
extended_tagged_text = self._extend_instruction_tags(tagged_text)
|
227 |
+
print(f"🔄 Extended tagged text: {extended_tagged_text}")
|
228 |
|
229 |
+
# Step 4: Merge close instruction tags
|
230 |
+
merged_tagged_text = self._merge_close_instruction_tags(extended_tagged_text, min_words_between=4)
|
231 |
+
print(f"🔗 Merged tagged text: {merged_tagged_text}")
|
232 |
|
233 |
+
# Step 5: Remove instruction tags and their content
|
234 |
+
sanitized_output = self._remove_instruction_tags(merged_tagged_text)
|
235 |
+
print(f"🔒 Sanitized output: {sanitized_output}")
|
236 |
return sanitized_output, annotations
|
237 |
|
238 |
except Exception as e:
|
|
|
328 |
from utils import predict_instructions
|
329 |
|
330 |
try:
|
331 |
+
# Use the predict_instructions function directly with token-level threshold
|
332 |
+
tokens, predictions = predict_instructions(self.model, self.tokenizer, text, self.device, self.token_threshold)
|
333 |
return predictions, tokens
|
334 |
except Exception as e:
|
335 |
+
print(f"⚠️ FALLBACK TRIGGERED: Error in predict_instructions: {e}")
|
336 |
+
print(f" Using _simple_predict as fallback (still uses threshold={self.token_threshold})")
|
337 |
# Fallback to simple tokenization if the complex method fails
|
338 |
return self._simple_predict(text)
|
339 |
|
|
|
364 |
self.model.eval()
|
365 |
with torch.no_grad():
|
366 |
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
367 |
+
# Use threshold approach (same as main prediction) instead of argmax
|
368 |
+
probs = torch.softmax(outputs['logits'], dim=-1)
|
369 |
+
predictions = (probs[:, :, 1] > self.token_threshold).long()
|
370 |
|
371 |
# Convert back to word-level predictions
|
372 |
word_ids = encoded.word_ids()
|
|
|
458 |
|
459 |
return result
|
460 |
|
461 |
+
def _extend_instruction_tags(self, tagged_text: str) -> str:
|
462 |
+
"""
|
463 |
+
Extend each <instruction>...</instruction> block by one word token on each side,
|
464 |
+
but only if the adjacent token is not already instruction-tagged.
|
465 |
+
|
466 |
+
This prevents overlaps between instruction blocks while extending them
|
467 |
+
to capture more context around detected instruction content.
|
468 |
+
|
469 |
+
Args:
|
470 |
+
tagged_text: Text with <instruction>...</instruction> tags
|
471 |
+
|
472 |
+
Returns:
|
473 |
+
Text with extended instruction tags
|
474 |
+
"""
|
475 |
+
if not tagged_text.strip():
|
476 |
+
return tagged_text
|
477 |
+
|
478 |
+
# Find all instruction regions first to avoid overlaps
|
479 |
+
instruction_regions = []
|
480 |
+
pattern = r'<instruction>(.*?)</instruction>'
|
481 |
+
|
482 |
+
for match in re.finditer(pattern, tagged_text, re.DOTALL):
|
483 |
+
instruction_regions.append({
|
484 |
+
'start': match.start(),
|
485 |
+
'end': match.end(),
|
486 |
+
'content': match.group(1)
|
487 |
+
})
|
488 |
+
|
489 |
+
if not instruction_regions:
|
490 |
+
return tagged_text
|
491 |
+
|
492 |
+
# Split into words while preserving positions
|
493 |
+
words = tagged_text.split()
|
494 |
+
|
495 |
+
# Build word-to-character position mapping
|
496 |
+
word_positions = []
|
497 |
+
char_pos = 0
|
498 |
+
for word in words:
|
499 |
+
start_pos = tagged_text.find(word, char_pos)
|
500 |
+
end_pos = start_pos + len(word)
|
501 |
+
word_positions.append({
|
502 |
+
'word': word,
|
503 |
+
'start': start_pos,
|
504 |
+
'end': end_pos
|
505 |
+
})
|
506 |
+
char_pos = end_pos
|
507 |
+
|
508 |
+
# Find which words are currently inside instruction tags
|
509 |
+
instruction_word_indices = set()
|
510 |
+
|
511 |
+
for region in instruction_regions:
|
512 |
+
for i, word_info in enumerate(word_positions):
|
513 |
+
# Check if word overlaps with instruction region
|
514 |
+
if (word_info['start'] < region['end'] and word_info['end'] > region['start']):
|
515 |
+
instruction_word_indices.add(i)
|
516 |
+
|
517 |
+
# Find instruction blocks by consecutive instruction words
|
518 |
+
instruction_blocks = []
|
519 |
+
current_block = None
|
520 |
+
|
521 |
+
for i in range(len(words)):
|
522 |
+
if i in instruction_word_indices:
|
523 |
+
if current_block is None:
|
524 |
+
current_block = {'start': i, 'end': i}
|
525 |
+
else:
|
526 |
+
current_block['end'] = i
|
527 |
+
else:
|
528 |
+
if current_block is not None:
|
529 |
+
instruction_blocks.append(current_block)
|
530 |
+
current_block = None
|
531 |
+
|
532 |
+
# Don't forget the last block
|
533 |
+
if current_block is not None:
|
534 |
+
instruction_blocks.append(current_block)
|
535 |
+
|
536 |
+
# Plan extensions with proper overlap prevention
|
537 |
+
extensions = []
|
538 |
+
planned_tagged_words = set(instruction_word_indices) # Start with currently tagged words
|
539 |
+
|
540 |
+
for block in instruction_blocks:
|
541 |
+
start_idx = block['start']
|
542 |
+
end_idx = block['end']
|
543 |
+
|
544 |
+
extend_left = False
|
545 |
+
extend_right = False
|
546 |
+
|
547 |
+
# Try extend left (if not at beginning and previous token not planned to be tagged)
|
548 |
+
if start_idx > 0 and (start_idx - 1) not in planned_tagged_words:
|
549 |
+
extend_left = True
|
550 |
+
planned_tagged_words.add(start_idx - 1) # Reserve this word
|
551 |
+
|
552 |
+
# Try extend right (if not at end and next token not planned to be tagged)
|
553 |
+
if end_idx < len(words) - 1 and (end_idx + 1) not in planned_tagged_words:
|
554 |
+
extend_right = True
|
555 |
+
planned_tagged_words.add(end_idx + 1) # Reserve this word
|
556 |
+
|
557 |
+
extensions.append({
|
558 |
+
'original_start': start_idx,
|
559 |
+
'original_end': end_idx,
|
560 |
+
'new_start': start_idx - (1 if extend_left else 0),
|
561 |
+
'new_end': end_idx + (1 if extend_right else 0),
|
562 |
+
'extend_left': extend_left,
|
563 |
+
'extend_right': extend_right
|
564 |
+
})
|
565 |
+
|
566 |
+
# Create a mapping of which extension block each word belongs to
|
567 |
+
word_to_block = {}
|
568 |
+
for block_idx, ext in enumerate(extensions):
|
569 |
+
for i in range(ext['new_start'], ext['new_end'] + 1):
|
570 |
+
word_to_block[i] = block_idx
|
571 |
+
|
572 |
+
# Reconstruct the text with separate instruction blocks
|
573 |
+
result_parts = []
|
574 |
+
current_block = None
|
575 |
+
|
576 |
+
for i, word in enumerate(words):
|
577 |
+
# Clean the word of existing tags
|
578 |
+
clean_word = word.replace('<instruction>', '').replace('</instruction>', '')
|
579 |
+
|
580 |
+
# Skip empty words (from empty instruction tags)
|
581 |
+
if not clean_word.strip():
|
582 |
+
continue
|
583 |
+
|
584 |
+
word_block = word_to_block.get(i, None)
|
585 |
+
|
586 |
+
if word_block is not None and current_block != word_block:
|
587 |
+
# Close previous block if needed
|
588 |
+
if current_block is not None:
|
589 |
+
result_parts[-1] += '</instruction>'
|
590 |
+
|
591 |
+
# Start new instruction block
|
592 |
+
result_parts.append(f'<instruction>{clean_word}')
|
593 |
+
current_block = word_block
|
594 |
+
|
595 |
+
elif word_block is not None and current_block == word_block:
|
596 |
+
# Continue current instruction block
|
597 |
+
result_parts.append(clean_word)
|
598 |
+
|
599 |
+
elif word_block is None and current_block is not None:
|
600 |
+
# End instruction block and add normal word
|
601 |
+
result_parts[-1] += '</instruction>'
|
602 |
+
result_parts.append(clean_word)
|
603 |
+
current_block = None
|
604 |
+
|
605 |
+
else:
|
606 |
+
# Normal word (not in any instruction block)
|
607 |
+
result_parts.append(clean_word)
|
608 |
+
|
609 |
+
# Close instruction if we ended inside one
|
610 |
+
if current_block is not None:
|
611 |
+
result_parts[-1] += '</instruction>'
|
612 |
+
|
613 |
+
return ' '.join(result_parts)
|
614 |
+
|
615 |
def _merge_close_instruction_tags(self, text, min_words_between=3):
|
616 |
"""
|
617 |
Merge <instruction>...</instruction> segments that are separated by less than min_words_between words
|
|
|
690 |
return None
|
691 |
return _sanitizer_instance
|
692 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
693 |
|
694 |
def sanitize_tool_output_with_annotations(tool_output, defense_enabled=True):
|
695 |
"""
|
utils.py
CHANGED
@@ -581,8 +581,20 @@ def collate_fn(batch):
|
|
581 |
'window_ends': [item['window_end'] for item in batch]
|
582 |
}
|
583 |
|
584 |
-
def predict_instructions(model, tokenizer, text: str, device=None):
|
585 |
-
"""Predict instructions in a given text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
586 |
# Auto-detect device if not provided
|
587 |
if device is None:
|
588 |
if torch.backends.mps.is_available():
|
@@ -613,7 +625,12 @@ def predict_instructions(model, tokenizer, text: str, device=None):
|
|
613 |
|
614 |
with torch.no_grad():
|
615 |
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
616 |
-
|
|
|
|
|
|
|
|
|
|
|
617 |
|
618 |
# Align predictions with original tokens
|
619 |
word_ids = encoded.word_ids()
|
|
|
581 |
'window_ends': [item['window_end'] for item in batch]
|
582 |
}
|
583 |
|
584 |
+
def predict_instructions(model, tokenizer, text: str, device=None, threshold=0.4):
|
585 |
+
"""Predict instructions in a given text
|
586 |
+
|
587 |
+
Args:
|
588 |
+
model: The trained instruction classifier model
|
589 |
+
tokenizer: The tokenizer for the model
|
590 |
+
text: Input text to analyze
|
591 |
+
device: Device to run inference on
|
592 |
+
threshold: Probability threshold for classifying tokens as INSTRUCTION.
|
593 |
+
Lower values = more aggressive detection (default: 0.4)
|
594 |
+
|
595 |
+
Returns:
|
596 |
+
tuple: (tokens, predictions) where predictions are 0=OTHER, 1=INSTRUCTION
|
597 |
+
"""
|
598 |
# Auto-detect device if not provided
|
599 |
if device is None:
|
600 |
if torch.backends.mps.is_available():
|
|
|
625 |
|
626 |
with torch.no_grad():
|
627 |
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
628 |
+
# Convert logits to probabilities
|
629 |
+
probs = torch.softmax(outputs['logits'], dim=-1)
|
630 |
+
# Use threshold on probability of class 1 (INSTRUCTION) instead of argmax
|
631 |
+
# This makes the classifier more aggressive - tokens are classified as INSTRUCTION
|
632 |
+
# if their probability of being INSTRUCTION is above the threshold
|
633 |
+
predictions = (probs[:, :, 1] > threshold).long()
|
634 |
|
635 |
# Align predictions with original tokens
|
636 |
word_ids = encoded.word_ids()
|