ddas committed on
Commit
40187f3
·
unverified ·
1 Parent(s): e2aa9a2

UI tagged preview added

Browse files
Files changed (4) hide show
  1. agent.py +67 -28
  2. app.py +107 -16
  3. instruction_classifier.py +18 -13
  4. utils.py +3 -133
agent.py CHANGED
@@ -551,26 +551,6 @@ Body: {email.body_value}"""
551
  # Import the instruction classifier sanitizer
552
  from instruction_classifier import sanitize_tool_output_with_annotations
553
 
554
-
555
- def extract_tool_calls(text):
556
- """Extract tool calls from LLM output (legacy function - kept for compatibility)"""
557
- tool_calls = []
558
-
559
- # Patterns to match tool calls
560
- patterns = [
561
- r'get_emails\(\)',
562
- r'search_email\(keyword=[^)]*\)', # search_email(keyword="UBS")
563
- r'search_email\(\s*"[^"]+"\s*\)', # search_email("UBS")
564
- r'send_email\([^)]+\)'
565
- ]
566
-
567
- for pattern in patterns:
568
- matches = re.findall(pattern, text)
569
- tool_calls.extend(matches)
570
-
571
- return tool_calls
572
-
573
-
574
  def extract_and_parse_tool_calls(text):
575
  """
576
  Extract tool calls from LLM output and parse them into structured format
@@ -714,6 +694,38 @@ def create_assistant_message_with_tool_calls(llm_output, parsed_tool_calls, prov
714
  return {"role": "assistant", "content": llm_output}
715
 
716
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
717
  def create_tool_result_message(tool_results, provider):
718
  """
719
  Create properly formatted tool result message based on LLM provider
@@ -843,6 +855,9 @@ def tool_agent_loop(user_query, inbox, system_prompt, model_name="gpt-4o-mini",
843
  # Track annotations for instruction classifier flagged content
844
  all_annotations = []
845
 
 
 
 
846
  # Initialize conversation with system prompt and user query
847
  # This will be used for LLM API calls (provider-specific format)
848
  llm_messages = [
@@ -968,7 +983,17 @@ def tool_agent_loop(user_query, inbox, system_prompt, model_name="gpt-4o-mini",
968
  # Conditional sanitization based on defense setting
969
  if defense_enabled:
970
  # Sanitize tool output with annotations
971
- sanitized_output, annotations = sanitize_tool_output_with_annotations(tool_output, defense_enabled)
 
 
 
 
 
 
 
 
 
 
972
 
973
  # Always add raw tool output to trace when defense is enabled
974
  raw_tool_message = {
@@ -988,17 +1013,26 @@ def tool_agent_loop(user_query, inbox, system_prompt, model_name="gpt-4o-mini",
988
 
989
  # Add annotations to our collection
990
  all_annotations.extend(annotations)
991
-
992
- annotation_msg = f"📝 Found {len(annotations)} instruction flags in tool output"
993
- execution_log.append(annotation_msg)
994
- print(annotation_msg)
 
 
995
 
996
- # Always show sanitized result in logs when defense is enabled
997
  sanitized_msg = f"🔒 Sanitized Result: {sanitized_output}"
998
  execution_log.append(sanitized_msg)
999
  print(sanitized_msg)
1000
 
1001
- # Always add sanitized tool output to trace when defense is enabled
 
 
 
 
 
 
 
1002
  sanitized_tool_message = {
1003
  "role": "tool",
1004
  "tool_call_id": tool_call_info['id'],
@@ -1144,4 +1178,9 @@ def tool_agent_loop(user_query, inbox, system_prompt, model_name="gpt-4o-mini",
1144
  final_trace_msg = f"📊 Trace push completed (with {len(all_annotations)} annotations)"
1145
  execution_log.append(final_trace_msg)
1146
 
1147
- return "\n".join(execution_log), llm_output
 
 
 
 
 
 
551
  # Import the instruction classifier sanitizer
552
  from instruction_classifier import sanitize_tool_output_with_annotations
553
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
554
  def extract_and_parse_tool_calls(text):
555
  """
556
  Extract tool calls from LLM output and parse them into structured format
 
694
  return {"role": "assistant", "content": llm_output}
695
 
696
 
697
def process_flagged_content(merged_tagged_text, target_phrase="To: [email protected]"):
    """
    Trim classifier-tagged text to what follows the second recipient header.

    The text is searched for two non-overlapping occurrences of
    ``target_phrase``. When both exist, everything up to and including the
    second one is dropped, the remainder is stripped, and a newline is
    inserted before each "Time:" / "Body:" label that is not already at the
    start of a line. When fewer than two occurrences exist, the text is
    returned unchanged.

    Args:
        merged_tagged_text (str): Text with <instruction> tags from classifier
        target_phrase (str): Header to look for. Defaults to the phrase the
            function was originally hard-coded with.
            NOTE(review): the default address appears redacted in the view
            this was reviewed from -- confirm the literal against the repo.

    Returns:
        str: "" for empty/None input; the trimmed and re-flowed remainder when
        a second occurrence exists; otherwise the input text as-is.
    """
    if not merged_tagged_text:
        return ""

    # An empty phrase "matches" at every offset, so treat it as nothing to trim.
    first_pos = merged_tagged_text.find(target_phrase) if target_phrase else -1
    if first_pos != -1:
        # Resume the search after the whole first match so an overlapping hit
        # cannot be counted as the second occurrence.
        second_pos = merged_tagged_text.find(target_phrase, first_pos + len(target_phrase))
        if second_pos != -1:
            # Remove everything before and including the second occurrence
            processed_text = merged_tagged_text[second_pos + len(target_phrase):].strip()
            print(f"🏷️ Found second occurrence at position {second_pos}, processed flagged content: {processed_text[:100]}...")
            # Put "Time:" and "Body:" on their own lines. The lookbehind
            # requires a preceding non-newline character, so a label that is
            # already at the start of a line -- including the very start of
            # the text -- is left alone instead of gaining a blank line.
            processed_text = re.sub(r'(?<=[^\n])(Time:|Body:)', r'\n\1', processed_text)
            return processed_text

    # If no second occurrence, return entire text
    print(f"🏷️ No second occurrence found, returning entire flagged content: {merged_tagged_text[:100]}...")
    return merged_tagged_text
727
+
728
+
729
  def create_tool_result_message(tool_results, provider):
730
  """
731
  Create properly formatted tool result message based on LLM provider
 
855
  # Track annotations for instruction classifier flagged content
856
  all_annotations = []
857
 
858
+ # Track flagged content for UI display
859
+ all_flagged_content = []
860
+
861
  # Initialize conversation with system prompt and user query
862
  # This will be used for LLM API calls (provider-specific format)
863
  llm_messages = [
 
983
  # Conditional sanitization based on defense setting
984
  if defense_enabled:
985
  # Sanitize tool output with annotations
986
+ sanitized_output, annotations, merged_tagged_text = sanitize_tool_output_with_annotations(tool_output, defense_enabled)
987
+
988
+ # Process and collect flagged content for UI display
989
+ print(f"🔍 DEBUG: merged_tagged_text: {merged_tagged_text}")
990
+ print(f"🔍 DEBUG: has <instruction> tags: {'<instruction>' in merged_tagged_text if merged_tagged_text else 'No text'}")
991
+ if merged_tagged_text and merged_tagged_text.strip() and "<instruction>" in merged_tagged_text:
992
+ processed_flagged = process_flagged_content(merged_tagged_text)
993
+ print(f"🔍 DEBUG: processed_flagged result: {processed_flagged}")
994
+ if processed_flagged:
995
+ all_flagged_content.append(processed_flagged)
996
+ print(f"🔍 DEBUG: Added to all_flagged_content. Total items: {len(all_flagged_content)}")
997
 
998
  # Always add raw tool output to trace when defense is enabled
999
  raw_tool_message = {
 
1013
 
1014
  # Add annotations to our collection
1015
  all_annotations.extend(annotations)
1016
+
1017
+
1018
+ # Add some spacing before sanitized output for clarity
1019
+ execution_log.append("")
1020
+ execution_log.append("--- DEFENSE PROCESSING ---")
1021
+ execution_log.append("")
1022
 
1023
+ # Show sanitized result in logs when defense is enabled
1024
  sanitized_msg = f"🔒 Sanitized Result: {sanitized_output}"
1025
  execution_log.append(sanitized_msg)
1026
  print(sanitized_msg)
1027
 
1028
+ # Add spacing separator in trace for clarity
1029
+ separator_message = {
1030
+ "role": "system",
1031
+ "content": "--- DEFENSE SANITIZATION APPLIED ---"
1032
+ }
1033
+ trace_messages.append(separator_message)
1034
+
1035
+ # Add sanitized tool output to trace when defense is enabled
1036
  sanitized_tool_message = {
1037
  "role": "tool",
1038
  "tool_call_id": tool_call_info['id'],
 
1178
  final_trace_msg = f"📊 Trace push completed (with {len(all_annotations)} annotations)"
1179
  execution_log.append(final_trace_msg)
1180
 
1181
+ # Combine all flagged content for UI display
1182
+ combined_flagged_content = "\n\n".join(all_flagged_content) if all_flagged_content else ""
1183
+ print(f"🔍 DEBUG: Final combined_flagged_content: '{combined_flagged_content}'")
1184
+ print(f"🔍 DEBUG: Length: {len(combined_flagged_content)} characters")
1185
+
1186
+ return "\n".join(execution_log), llm_output, combined_flagged_content
app.py CHANGED
@@ -526,13 +526,6 @@ def is_likely_gibberish_soft(text):
526
 
527
  return False # Passes soft gibberish checks
528
 
529
- def validate_english_with_model_loading(text):
530
- """
531
- Convenience function that handles FastText model loading automatically.
532
- """
533
- model = load_fasttext_model() # This will download and load the model if needed
534
- return validate_english_only_windowed(text, model)
535
-
536
  def get_fasttext_confidence_scores(text, model=None, top_k=3):
537
  """
538
  Get top language confidence scores from FastText without doing validation.
@@ -761,7 +754,7 @@ def submit_attack(from_addr, attack_subject, attack_body, model_name="gpt-4o", d
761
  }
762
 
763
  # Process the fixed user query with the tool agent loop
764
- execution_log, final_output = tool_agent_loop(
765
  user_query=USER_INPUT,
766
  inbox=INBOX,
767
  system_prompt=SYSTEM_PROMPT,
@@ -771,13 +764,13 @@ def submit_attack(from_addr, attack_subject, attack_body, model_name="gpt-4o", d
771
  fasttext_confidence_scores=fasttext_confidence_scores
772
  )
773
 
774
- # Return execution log and final output separately
775
- return execution_log, final_output
776
 
777
  except Exception as e:
778
  error_msg = f"❌ Error processing attack: {str(e)}"
779
  print(error_msg)
780
- return "", error_msg
781
 
782
  def reset_to_initial_state():
783
  """Reset the inbox to original state and clear all inputs"""
@@ -1175,6 +1168,74 @@ def create_interface():
1175
  .results-card ul { margin: 0; padding-left: 16px; }
1176
  .results-card li { margin: 4px 0; }
1177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1178
  /* Error Modal Popup Styling */
1179
  .error-modal-overlay {
1180
  position: fixed !important;
@@ -1512,6 +1573,14 @@ Satya
1512
  )
1513
  # Attack results summary (pretty list)
1514
  results_display = gr.HTML("", elem_id="attack-results")
 
 
 
 
 
 
 
 
1515
  with gr.Accordion("Show Execution Trace", open=False):
1516
  trace_display = gr.Textbox(
1517
  lines=14,
@@ -1705,7 +1774,9 @@ Satya
1705
  gr.update(), # email1_display - no change
1706
  gr.update(), # email2_display - no change
1707
  gr.update(), # email3_display - no change
1708
- gr.update(value=modal_html, visible=True) # error_modal_html
 
 
1709
  )
1710
 
1711
  print("✅ ALL VALIDATION PASSED - proceeding with attack submission")
@@ -1717,7 +1788,7 @@ Satya
1717
  }
1718
 
1719
  try:
1720
- exec_log, final_out = submit_attack(from_addr.strip(), subject, body, model, defense_enabled, user_info.strip(), confidence_scores)
1721
  except Exception as e:
1722
  # Handle any setup or execution errors with detailed messages
1723
  error_str = str(e).lower()
@@ -1773,7 +1844,9 @@ Satya
1773
  gr.update(), # email1_display - no change
1774
  gr.update(), # email2_display - no change
1775
  gr.update(), # email3_display - no change
1776
- gr.update(value=modal_html, visible=True) # error_modal_html
 
 
1777
  )
1778
 
1779
  # Build a formatted results summary extracted from exec_log
@@ -1818,16 +1891,34 @@ Satya
1818
  for i, email in enumerate(emails_to_display):
1819
  updated_emails.append(format_single_email(email, i + 1))
1820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1821
  # Return results with hidden error modal (validation passed)
1822
  success_timestamp = int(time.time() * 1000)
1823
  print(f"✅ Validation successful at {success_timestamp} - hiding error modal")
1824
  return (final_out, results_html, exec_log, updated_emails[0], updated_emails[1], updated_emails[2],
1825
- gr.update(value="", visible=False)) # Hide error modal
 
 
1826
 
1827
  submit_btn.click(
1828
  fn=submit_and_update,
1829
  inputs=[attack_from, attack_subject, attack_body, model_selector, defense_toggle, user_info],
1830
- outputs=[final_output_display, results_display, trace_display, email1_display, email2_display, email3_display, error_modal_html]
1831
  )
1832
 
1833
  # Connect dismiss trigger to properly hide the modal
 
526
 
527
  return False # Passes soft gibberish checks
528
 
 
 
 
 
 
 
 
529
  def get_fasttext_confidence_scores(text, model=None, top_k=3):
530
  """
531
  Get top language confidence scores from FastText without doing validation.
 
754
  }
755
 
756
  # Process the fixed user query with the tool agent loop
757
+ execution_log, final_output, flagged_content = tool_agent_loop(
758
  user_query=USER_INPUT,
759
  inbox=INBOX,
760
  system_prompt=SYSTEM_PROMPT,
 
764
  fasttext_confidence_scores=fasttext_confidence_scores
765
  )
766
 
767
+ # Return execution log, final output, and flagged content separately
768
+ return execution_log, final_output, flagged_content
769
 
770
  except Exception as e:
771
  error_msg = f"❌ Error processing attack: {str(e)}"
772
  print(error_msg)
773
+ return "", error_msg, ""
774
 
775
  def reset_to_initial_state():
776
  """Reset the inbox to original state and clear all inputs"""
 
1168
  .results-card ul { margin: 0; padding-left: 16px; }
1169
  .results-card li { margin: 4px 0; }
1170
 
1171
+
1172
+
1173
+ /* Accordion content styling for flagged content */
1174
+ .gr-accordion .gr-panel:has([data-testid="HTML"]) {
1175
+ max-height: 300px !important;
1176
+ overflow-y: auto !important;
1177
+ padding: 16px !important;
1178
+ background: white !important;
1179
+ border-radius: 8px !important;
1180
+ font-family: 'Roboto', sans-serif !important;
1181
+ line-height: 1.6 !important;
1182
+ color: #333333 !important;
1183
+ word-wrap: break-word !important;
1184
+ overflow-wrap: break-word !important;
1185
+ scrollbar-width: thin !important;
1186
+ }
1187
+
1188
+ /* Scrollbar styling for accordion content */
1189
+ .gr-accordion .gr-panel:has([data-testid="HTML"])::-webkit-scrollbar {
1190
+ width: 8px !important;
1191
+ }
1192
+
1193
+ .gr-accordion .gr-panel:has([data-testid="HTML"])::-webkit-scrollbar-track {
1194
+ background: rgba(0,0,0,0.1) !important;
1195
+ border-radius: 4px !important;
1196
+ }
1197
+
1198
+ .gr-accordion .gr-panel:has([data-testid="HTML"])::-webkit-scrollbar-thumb {
1199
+ background: rgba(0,0,0,0.3) !important;
1200
+ border-radius: 4px !important;
1201
+ }
1202
+
1203
+ .gr-accordion .gr-panel:has([data-testid="HTML"])::-webkit-scrollbar-thumb:hover {
1204
+ background: rgba(0,0,0,0.5) !important;
1205
+ }
1206
+
1207
+ /* Instruction tag styling for light mode */
1208
+ instruction {
1209
+ background-color: #ffebee !important;
1210
+ color: #c62828 !important;
1211
+ padding: 2px 6px !important;
1212
+ border-radius: 4px !important;
1213
+ font-weight: 600 !important;
1214
+ border: 1px solid #ef5350 !important;
1215
+ box-shadow: 0 1px 2px rgba(198, 40, 40, 0.2) !important;
1216
+ display: inline !important;
1217
+ font-family: 'Roboto', sans-serif !important;
1218
+ font-size: 14px !important;
1219
+ line-height: 1.4 !important;
1220
+ margin: 0 2px !important;
1221
+ }
1222
+
1223
+ /* Instruction tag styling for dark mode */
1224
+ @media (prefers-color-scheme: dark) {
1225
+ instruction {
1226
+ background-color: rgb(84 37 37) !important;
1227
+ color: #ffffff !important;
1228
+ border: 1px solid #d32f2f !important;
1229
+ box-shadow: 0 1px 3px rgba(183, 28, 28, 0.4) !important;
1230
+ }
1231
+
1232
+ /* Also ensure accordion content has proper dark mode styling */
1233
+ .gr-accordion .gr-panel:has([data-testid="HTML"]) {
1234
+ background: var(--background-fill-primary) !important;
1235
+ color: var(--body-text-color) !important;
1236
+ }
1237
+ }
1238
+
1239
  /* Error Modal Popup Styling */
1240
  .error-modal-overlay {
1241
  position: fixed !important;
 
1573
  )
1574
  # Attack results summary (pretty list)
1575
  results_display = gr.HTML("", elem_id="attack-results")
1576
+
1577
+ # Flagged content display (only shown when defense enabled and content found)
1578
+ with gr.Accordion("Show What was Flagged", open=False, visible=False) as flagged_accordion:
1579
+ flagged_content_display = gr.HTML(
1580
+ "",
1581
+ show_label=False
1582
+ )
1583
+
1584
  with gr.Accordion("Show Execution Trace", open=False):
1585
  trace_display = gr.Textbox(
1586
  lines=14,
 
1774
  gr.update(), # email1_display - no change
1775
  gr.update(), # email2_display - no change
1776
  gr.update(), # email3_display - no change
1777
+ gr.update(value=modal_html, visible=True), # error_modal_html
1778
+ gr.update(), # flagged_accordion - no change
1779
+ gr.update() # flagged_content_display - no change
1780
  )
1781
 
1782
  print("✅ ALL VALIDATION PASSED - proceeding with attack submission")
 
1788
  }
1789
 
1790
  try:
1791
+ exec_log, final_out, flagged_content = submit_attack(from_addr.strip(), subject, body, model, defense_enabled, user_info.strip(), confidence_scores)
1792
  except Exception as e:
1793
  # Handle any setup or execution errors with detailed messages
1794
  error_str = str(e).lower()
 
1844
  gr.update(), # email1_display - no change
1845
  gr.update(), # email2_display - no change
1846
  gr.update(), # email3_display - no change
1847
+ gr.update(value=modal_html, visible=True), # error_modal_html
1848
+ gr.update(), # flagged_accordion - no change
1849
+ gr.update() # flagged_content_display - no change
1850
  )
1851
 
1852
  # Build a formatted results summary extracted from exec_log
 
1891
  for i, email in enumerate(emails_to_display):
1892
  updated_emails.append(format_single_email(email, i + 1))
1893
 
1894
+ # Process flagged content for display
1895
+ flagged_display_html = ""
1896
+ flagged_accordion_visible = False
1897
+ flagged_accordion_open = False
1898
+
1899
+ if defense_enabled and flagged_content and flagged_content.strip():
1900
+ # Convert newlines to HTML line breaks for proper rendering
1901
+ flagged_content_html = flagged_content.replace('\n', '<br>')
1902
+ # Simple HTML structure without extra containers
1903
+ flagged_display_html = flagged_content_html
1904
+ flagged_accordion_visible = True
1905
+ flagged_accordion_open = True # Open after submit when there's content
1906
+ print(f"🏷️ Flagged content prepared for UI: {len(flagged_content)} characters")
1907
+ else:
1908
+ print("🏷️ No flagged content to display")
1909
+
1910
  # Return results with hidden error modal (validation passed)
1911
  success_timestamp = int(time.time() * 1000)
1912
  print(f"✅ Validation successful at {success_timestamp} - hiding error modal")
1913
  return (final_out, results_html, exec_log, updated_emails[0], updated_emails[1], updated_emails[2],
1914
+ gr.update(value="", visible=False), # Hide error modal
1915
+ gr.update(visible=flagged_accordion_visible, open=flagged_accordion_open), # Update flagged accordion
1916
+ gr.update(value=flagged_display_html)) # Update flagged content
1917
 
1918
  submit_btn.click(
1919
  fn=submit_and_update,
1920
  inputs=[attack_from, attack_subject, attack_body, model_selector, defense_toggle, user_info],
1921
+ outputs=[final_output_display, results_display, trace_display, email1_display, email2_display, email3_display, error_modal_html, flagged_accordion, flagged_content_display]
1922
  )
1923
 
1924
  # Connect dismiss trigger to properly hide the modal
instruction_classifier.py CHANGED
@@ -186,7 +186,7 @@ class InstructionClassifierSanitizer:
186
 
187
 
188
  @spaces.GPU
189
- def sanitize_with_annotations(self, tool_output: str) -> Tuple[str, List[Dict[str, any]]]:
190
  """
191
  Sanitization function that also returns annotation data for flagged content.
192
 
@@ -194,11 +194,13 @@ class InstructionClassifierSanitizer:
194
  tool_output: The raw tool output string
195
 
196
  Returns:
197
- Tuple of (sanitized_output, annotations) where annotations contain
198
- position information for content that was flagged by the classifier
 
 
199
  """
200
  if not tool_output or not tool_output.strip():
201
- return tool_output, []
202
 
203
  # Move model to target device (GPU) within @spaces.GPU decorated method
204
  if self.device != self.target_device:
@@ -214,7 +216,7 @@ class InstructionClassifierSanitizer:
214
 
215
  if not is_injection:
216
  print("✅ No injection detected - returning original output")
217
- return tool_output, []
218
 
219
  print(f"🚨 Injection detected! Processing with extensions and annotations...")
220
 
@@ -233,12 +235,13 @@ class InstructionClassifierSanitizer:
233
  # Step 5: Remove instruction tags and their content
234
  sanitized_output = self._remove_instruction_tags(merged_tagged_text)
235
  print(f"🔒 Sanitized output: {sanitized_output}")
236
- return sanitized_output, annotations
 
237
 
238
  except Exception as e:
239
  print(f"❌ Error in instruction classifier sanitization: {e}")
240
  # Return original output if sanitization fails
241
- return tool_output, []
242
 
243
  def _extract_annotations_from_tagged_text(self, tagged_text: str, original_text: str) -> List[Dict[str, any]]:
244
  """
@@ -700,23 +703,25 @@ def sanitize_tool_output_with_annotations(tool_output, defense_enabled=True):
700
  defense_enabled: Whether defense is enabled (passed from agent)
701
 
702
  Returns:
703
- Tuple of (sanitized_output, annotations) where annotations contain
704
- position information for content that was flagged by the classifier
 
 
705
  """
706
  print(f"🔍 sanitize_tool_output_with_annotations called with: {tool_output[:100]}...")
707
 
708
  # If defense is disabled globally, return original output with no annotations
709
  if not defense_enabled:
710
  print("⚠️ Defense disabled - returning original output without processing")
711
- return tool_output, []
712
 
713
  sanitizer = get_sanitizer()
714
  if sanitizer is None:
715
  print("⚠️ Instruction classifier not available, returning original output")
716
- return tool_output, []
717
 
718
  print("✅ Sanitizer found, processing with annotations...")
719
- sanitized_output, annotations = sanitizer.sanitize_with_annotations(tool_output)
720
  print(f"🔒 Sanitization complete, result: {sanitized_output[:100]}...")
721
  print(f"📝 Found {len(annotations)} annotations")
722
- return sanitized_output, annotations
 
186
 
187
 
188
  @spaces.GPU
189
+ def sanitize_with_annotations(self, tool_output: str) -> Tuple[str, List[Dict[str, any]], str]:
190
  """
191
  Sanitization function that also returns annotation data for flagged content.
192
 
 
194
  tool_output: The raw tool output string
195
 
196
  Returns:
197
+ Tuple of (sanitized_output, annotations, merged_tagged_text) where:
198
+ - sanitized_output: cleaned text with instruction content removed
199
+ - annotations: position information for content flagged by classifier
200
+ - merged_tagged_text: text with <instruction> tags showing detected content
201
  """
202
  if not tool_output or not tool_output.strip():
203
+ return tool_output, [], tool_output
204
 
205
  # Move model to target device (GPU) within @spaces.GPU decorated method
206
  if self.device != self.target_device:
 
216
 
217
  if not is_injection:
218
  print("✅ No injection detected - returning original output")
219
+ return tool_output, [], tool_output
220
 
221
  print(f"🚨 Injection detected! Processing with extensions and annotations...")
222
 
 
235
  # Step 5: Remove instruction tags and their content
236
  sanitized_output = self._remove_instruction_tags(merged_tagged_text)
237
  print(f"🔒 Sanitized output: {sanitized_output}")
238
+ print(f"🔍 DEBUG SANITIZER: Returning merged_tagged_text: '{merged_tagged_text}'")
239
+ return sanitized_output, annotations, merged_tagged_text
240
 
241
  except Exception as e:
242
  print(f"❌ Error in instruction classifier sanitization: {e}")
243
  # Return original output if sanitization fails
244
+ return tool_output, [], tool_output
245
 
246
  def _extract_annotations_from_tagged_text(self, tagged_text: str, original_text: str) -> List[Dict[str, any]]:
247
  """
 
703
  defense_enabled: Whether defense is enabled (passed from agent)
704
 
705
  Returns:
706
+ Tuple of (sanitized_output, annotations, merged_tagged_text) where:
707
+ - sanitized_output: cleaned text with instruction content removed
708
+ - annotations: position information for content flagged by classifier
709
+ - merged_tagged_text: text with <instruction> tags showing detected content
710
  """
711
  print(f"🔍 sanitize_tool_output_with_annotations called with: {tool_output[:100]}...")
712
 
713
  # If defense is disabled globally, return original output with no annotations
714
  if not defense_enabled:
715
  print("⚠️ Defense disabled - returning original output without processing")
716
+ return tool_output, [], tool_output
717
 
718
  sanitizer = get_sanitizer()
719
  if sanitizer is None:
720
  print("⚠️ Instruction classifier not available, returning original output")
721
+ return tool_output, [], tool_output
722
 
723
  print("✅ Sanitizer found, processing with annotations...")
724
+ sanitized_output, annotations, merged_tagged_text = sanitizer.sanitize_with_annotations(tool_output)
725
  print(f"🔒 Sanitization complete, result: {sanitized_output[:100]}...")
726
  print(f"📝 Found {len(annotations)} annotations")
727
+ return sanitized_output, annotations, merged_tagged_text
utils.py CHANGED
@@ -1,147 +1,19 @@
1
  import json
2
  import torch
3
  import torch.nn as nn
4
- from torch.utils.data import Dataset, DataLoader
5
- from transformers import AutoTokenizer, AutoModel, AutoConfig
6
- import numpy as np
7
- from tqdm import tqdm
8
  import re
9
- from typing import List, Tuple, Dict, Any
10
  import warnings
11
  import logging
12
  import os
13
- from datetime import datetime
14
- from sklearn.utils.class_weight import compute_class_weight
15
- import torch.nn.functional as F
16
 
17
  # Disable tokenizer parallelism to avoid forking warnings
18
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
19
 
20
  warnings.filterwarnings('ignore')
21
 
22
- def set_random_seeds(seed=42):
23
- """Set random seeds for reproducibility"""
24
- import random
25
- import numpy as np
26
- import torch
27
-
28
- random.seed(seed)
29
- np.random.seed(seed)
30
- torch.manual_seed(seed)
31
- torch.cuda.manual_seed_all(seed) # For multi-GPU
32
-
33
- # Make CuDNN deterministic (slower but reproducible)
34
- torch.backends.cudnn.deterministic = True
35
- torch.backends.cudnn.benchmark = False
36
-
37
- def setup_logging(log_dir='data/logs'):
38
- """Setup logging configuration"""
39
- # Create logs directory if it doesn't exist
40
- os.makedirs(log_dir, exist_ok=True)
41
-
42
- # Create timestamp for log file
43
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
44
- log_file = os.path.join(log_dir, f'training_log_{timestamp}.log')
45
-
46
- # Configure logging
47
- logging.basicConfig(
48
- level=logging.INFO, # Back to INFO level
49
- format='%(asctime)s - %(levelname)s - %(message)s',
50
- handlers=[
51
- logging.FileHandler(log_file),
52
- logging.StreamHandler() # Also print to console
53
- ]
54
- )
55
-
56
- logger = logging.getLogger(__name__)
57
- logger.info(f"Logging initialized. Log file: {log_file}")
58
- return logger, log_file
59
-
60
- def check_gpu_availability():
61
- """Check and print GPU availability information"""
62
- logger = logging.getLogger(__name__)
63
- logger.info("=== GPU Availability Check ===")
64
-
65
- if torch.backends.mps.is_available():
66
- logger.info("✓ MPS (Apple Silicon GPU) is available")
67
- if torch.backends.mps.is_built():
68
- logger.info("✓ MPS is built into PyTorch")
69
- else:
70
- logger.info("✗ MPS is not built into PyTorch")
71
- else:
72
- logger.info("✗ MPS (Apple Silicon GPU) is not available")
73
-
74
- if torch.cuda.is_available():
75
- logger.info(f"✓ CUDA is available (GPU count: {torch.cuda.device_count()})")
76
- else:
77
- logger.info("✗ CUDA is not available")
78
-
79
- logger.info(f"PyTorch version: {torch.__version__}")
80
- logger.info("=" * 50)
81
-
82
- def calculate_class_weights(dataset):
83
- """Calculate class weights for imbalanced dataset using BERT paper approach"""
84
- logger = logging.getLogger(__name__)
85
-
86
- # Collect all labels from the dataset (BERT approach: only first subtokens have real labels)
87
- all_labels = []
88
- for window_data in dataset.processed_data:
89
- # Filter out -100 labels (special tokens + subsequent subtokens of same word)
90
- # This gives us true word-level class distribution
91
- valid_labels = [label for label in window_data['subword_labels'] if label != -100]
92
- all_labels.extend(valid_labels)
93
-
94
- # Convert to numpy array
95
- y = np.array(all_labels)
96
-
97
- # Calculate class weights using sklearn
98
- classes = np.unique(y)
99
- class_weights = compute_class_weight('balanced', classes=classes, y=y)
100
-
101
- # Create weight tensor
102
- weight_tensor = torch.FloatTensor(class_weights)
103
-
104
- logger.info(f"Word-level class distribution: {np.bincount(y)}")
105
- logger.info(f"Class 0 (Non-instruction words): {np.sum(y == 0)} words ({np.sum(y == 0)/len(y)*100:.1f}%)")
106
- logger.info(f"Class 1 (Instruction words): {np.sum(y == 1)} words ({np.sum(y == 1)/len(y)*100:.1f}%)")
107
- logger.info(f"Calculated class weights (word-level): {class_weights}")
108
- logger.info(f" Weight for class 0 (Non-instruction): {class_weights[0]:.4f}")
109
- logger.info(f" Weight for class 1 (Instruction): {class_weights[1]:.4f}")
110
-
111
- return weight_tensor
112
-
113
- class FocalLoss(nn.Module):
114
- """Focal Loss for addressing class imbalance"""
115
- def __init__(self, alpha=1, gamma=2, ignore_index=-100):
116
- super(FocalLoss, self).__init__()
117
- self.alpha = alpha
118
- self.gamma = gamma
119
- self.ignore_index = ignore_index
120
-
121
- def forward(self, inputs, targets):
122
- # Flatten inputs and targets
123
- inputs = inputs.view(-1, inputs.size(-1))
124
- targets = targets.view(-1)
125
-
126
- # Create mask for non-ignored indices
127
- mask = targets != self.ignore_index
128
- targets = targets[mask]
129
- inputs = inputs[mask]
130
-
131
- if len(targets) == 0:
132
- return torch.tensor(0.0, requires_grad=True, device=inputs.device)
133
-
134
- # Calculate cross entropy
135
- ce_loss = F.cross_entropy(inputs, targets, reduction='none')
136
-
137
- # Calculate pt
138
- pt = torch.exp(-ce_loss)
139
-
140
- # Calculate focal loss
141
- focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
142
-
143
- return focal_loss.mean()
144
-
145
  class InstructionDataset(Dataset):
146
  def __init__(self, data_path: str, tokenizer, max_length: int = 512, is_training: bool = True,
147
  window_size: int = 512, overlap: int = 100):
@@ -517,8 +389,6 @@ class TransformerInstructionClassifier(nn.Module):
517
  # Setup loss function based on type
518
  if loss_type == 'weighted_ce':
519
  self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, weight=class_weights)
520
- elif loss_type == 'focal':
521
- self.loss_fct = FocalLoss(alpha=1, gamma=2, ignore_index=-100)
522
  else:
523
  self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
524
 
 
1
  import json
2
  import torch
3
  import torch.nn as nn
4
+ from torch.utils.data import Dataset
5
+ from transformers import AutoModel
 
 
6
  import re
7
+ from typing import List, Dict, Any
8
  import warnings
9
  import logging
10
  import os
 
 
 
11
 
12
  # Disable tokenizer parallelism to avoid forking warnings
13
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
 
15
  warnings.filterwarnings('ignore')
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  class InstructionDataset(Dataset):
18
  def __init__(self, data_path: str, tokenizer, max_length: int = 512, is_training: bool = True,
19
  window_size: int = 512, overlap: int = 100):
 
389
  # Setup loss function based on type
390
  if loss_type == 'weighted_ce':
391
  self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, weight=class_weights)
 
 
392
  else:
393
  self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
394