Spaces:
Running
on
Zero
Running
on
Zero
UI tagged preview added
Browse files- agent.py +67 -28
- app.py +107 -16
- instruction_classifier.py +18 -13
- utils.py +3 -133
agent.py
CHANGED
|
@@ -551,26 +551,6 @@ Body: {email.body_value}"""
|
|
| 551 |
# Import the instruction classifier sanitizer
|
| 552 |
from instruction_classifier import sanitize_tool_output_with_annotations
|
| 553 |
|
| 554 |
-
|
| 555 |
-
def extract_tool_calls(text):
|
| 556 |
-
"""Extract tool calls from LLM output (legacy function - kept for compatibility)"""
|
| 557 |
-
tool_calls = []
|
| 558 |
-
|
| 559 |
-
# Patterns to match tool calls
|
| 560 |
-
patterns = [
|
| 561 |
-
r'get_emails\(\)',
|
| 562 |
-
r'search_email\(keyword=[^)]*\)', # search_email(keyword="UBS")
|
| 563 |
-
r'search_email\(\s*"[^"]+"\s*\)', # search_email("UBS")
|
| 564 |
-
r'send_email\([^)]+\)'
|
| 565 |
-
]
|
| 566 |
-
|
| 567 |
-
for pattern in patterns:
|
| 568 |
-
matches = re.findall(pattern, text)
|
| 569 |
-
tool_calls.extend(matches)
|
| 570 |
-
|
| 571 |
-
return tool_calls
|
| 572 |
-
|
| 573 |
-
|
| 574 |
def extract_and_parse_tool_calls(text):
|
| 575 |
"""
|
| 576 |
Extract tool calls from LLM output and parse them into structured format
|
|
@@ -714,6 +694,38 @@ def create_assistant_message_with_tool_calls(llm_output, parsed_tool_calls, prov
|
|
| 714 |
return {"role": "assistant", "content": llm_output}
|
| 715 |
|
| 716 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 717 |
def create_tool_result_message(tool_results, provider):
|
| 718 |
"""
|
| 719 |
Create properly formatted tool result message based on LLM provider
|
|
@@ -843,6 +855,9 @@ def tool_agent_loop(user_query, inbox, system_prompt, model_name="gpt-4o-mini",
|
|
| 843 |
# Track annotations for instruction classifier flagged content
|
| 844 |
all_annotations = []
|
| 845 |
|
|
|
|
|
|
|
|
|
|
| 846 |
# Initialize conversation with system prompt and user query
|
| 847 |
# This will be used for LLM API calls (provider-specific format)
|
| 848 |
llm_messages = [
|
|
@@ -968,7 +983,17 @@ def tool_agent_loop(user_query, inbox, system_prompt, model_name="gpt-4o-mini",
|
|
| 968 |
# Conditional sanitization based on defense setting
|
| 969 |
if defense_enabled:
|
| 970 |
# Sanitize tool output with annotations
|
| 971 |
-
sanitized_output, annotations = sanitize_tool_output_with_annotations(tool_output, defense_enabled)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 972 |
|
| 973 |
# Always add raw tool output to trace when defense is enabled
|
| 974 |
raw_tool_message = {
|
|
@@ -988,17 +1013,26 @@ def tool_agent_loop(user_query, inbox, system_prompt, model_name="gpt-4o-mini",
|
|
| 988 |
|
| 989 |
# Add annotations to our collection
|
| 990 |
all_annotations.extend(annotations)
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
|
|
|
|
|
|
|
| 995 |
|
| 996 |
-
#
|
| 997 |
sanitized_msg = f"🔒 Sanitized Result: {sanitized_output}"
|
| 998 |
execution_log.append(sanitized_msg)
|
| 999 |
print(sanitized_msg)
|
| 1000 |
|
| 1001 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1002 |
sanitized_tool_message = {
|
| 1003 |
"role": "tool",
|
| 1004 |
"tool_call_id": tool_call_info['id'],
|
|
@@ -1144,4 +1178,9 @@ def tool_agent_loop(user_query, inbox, system_prompt, model_name="gpt-4o-mini",
|
|
| 1144 |
final_trace_msg = f"📊 Trace push completed (with {len(all_annotations)} annotations)"
|
| 1145 |
execution_log.append(final_trace_msg)
|
| 1146 |
|
| 1147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
# Import the instruction classifier sanitizer
|
| 552 |
from instruction_classifier import sanitize_tool_output_with_annotations
|
| 553 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 554 |
def extract_and_parse_tool_calls(text):
|
| 555 |
"""
|
| 556 |
Extract tool calls from LLM output and parse them into structured format
|
|
|
|
| 694 |
return {"role": "assistant", "content": llm_output}
|
| 695 |
|
| 696 |
|
| 697 |
+
def process_flagged_content(merged_tagged_text):
|
| 698 |
+
"""
|
| 699 |
+
Process tagged text to find second 'To: [email protected]' and trim before it.
|
| 700 |
+
|
| 701 |
+
Args:
|
| 702 |
+
merged_tagged_text (str): Text with <instruction> tags from classifier
|
| 703 |
+
|
| 704 |
+
Returns:
|
| 705 |
+
str: Processed text with content before second occurrence removed
|
| 706 |
+
"""
|
| 707 |
+
if not merged_tagged_text:
|
| 708 |
+
return ""
|
| 709 |
+
|
| 710 |
+
# Find second occurrence of "To: [email protected]"
|
| 711 |
+
target_phrase = "To: [email protected]"
|
| 712 |
+
first_pos = merged_tagged_text.find(target_phrase)
|
| 713 |
+
if first_pos != -1:
|
| 714 |
+
second_pos = merged_tagged_text.find(target_phrase, first_pos + 1)
|
| 715 |
+
if second_pos != -1:
|
| 716 |
+
# Remove everything before and including the second occurrence
|
| 717 |
+
processed_text = merged_tagged_text[second_pos + len(target_phrase):].strip()
|
| 718 |
+
print(f"🏷️ Found second occurrence at position {second_pos}, processed flagged content: {processed_text[:100]}...")
|
| 719 |
+
# Insert newline before "Time:" and "Body:" (but not if already at start of line)
|
| 720 |
+
processed_text = re.sub(r'(?<!\n)(Time:)', r'\n\1', processed_text)
|
| 721 |
+
processed_text = re.sub(r'(?<!\n)(Body:)', r'\n\1', processed_text)
|
| 722 |
+
return processed_text
|
| 723 |
+
|
| 724 |
+
# If no second occurrence, return entire text
|
| 725 |
+
print(f"🏷️ No second occurrence found, returning entire flagged content: {merged_tagged_text[:100]}...")
|
| 726 |
+
return merged_tagged_text
|
| 727 |
+
|
| 728 |
+
|
| 729 |
def create_tool_result_message(tool_results, provider):
|
| 730 |
"""
|
| 731 |
Create properly formatted tool result message based on LLM provider
|
|
|
|
| 855 |
# Track annotations for instruction classifier flagged content
|
| 856 |
all_annotations = []
|
| 857 |
|
| 858 |
+
# Track flagged content for UI display
|
| 859 |
+
all_flagged_content = []
|
| 860 |
+
|
| 861 |
# Initialize conversation with system prompt and user query
|
| 862 |
# This will be used for LLM API calls (provider-specific format)
|
| 863 |
llm_messages = [
|
|
|
|
| 983 |
# Conditional sanitization based on defense setting
|
| 984 |
if defense_enabled:
|
| 985 |
# Sanitize tool output with annotations
|
| 986 |
+
sanitized_output, annotations, merged_tagged_text = sanitize_tool_output_with_annotations(tool_output, defense_enabled)
|
| 987 |
+
|
| 988 |
+
# Process and collect flagged content for UI display
|
| 989 |
+
print(f"🔍 DEBUG: merged_tagged_text: {merged_tagged_text}")
|
| 990 |
+
print(f"🔍 DEBUG: has <instruction> tags: {'<instruction>' in merged_tagged_text if merged_tagged_text else 'No text'}")
|
| 991 |
+
if merged_tagged_text and merged_tagged_text.strip() and "<instruction>" in merged_tagged_text:
|
| 992 |
+
processed_flagged = process_flagged_content(merged_tagged_text)
|
| 993 |
+
print(f"🔍 DEBUG: processed_flagged result: {processed_flagged}")
|
| 994 |
+
if processed_flagged:
|
| 995 |
+
all_flagged_content.append(processed_flagged)
|
| 996 |
+
print(f"🔍 DEBUG: Added to all_flagged_content. Total items: {len(all_flagged_content)}")
|
| 997 |
|
| 998 |
# Always add raw tool output to trace when defense is enabled
|
| 999 |
raw_tool_message = {
|
|
|
|
| 1013 |
|
| 1014 |
# Add annotations to our collection
|
| 1015 |
all_annotations.extend(annotations)
|
| 1016 |
+
|
| 1017 |
+
|
| 1018 |
+
# Add some spacing before sanitized output for clarity
|
| 1019 |
+
execution_log.append("")
|
| 1020 |
+
execution_log.append("--- DEFENSE PROCESSING ---")
|
| 1021 |
+
execution_log.append("")
|
| 1022 |
|
| 1023 |
+
# Show sanitized result in logs when defense is enabled
|
| 1024 |
sanitized_msg = f"🔒 Sanitized Result: {sanitized_output}"
|
| 1025 |
execution_log.append(sanitized_msg)
|
| 1026 |
print(sanitized_msg)
|
| 1027 |
|
| 1028 |
+
# Add spacing separator in trace for clarity
|
| 1029 |
+
separator_message = {
|
| 1030 |
+
"role": "system",
|
| 1031 |
+
"content": "--- DEFENSE SANITIZATION APPLIED ---"
|
| 1032 |
+
}
|
| 1033 |
+
trace_messages.append(separator_message)
|
| 1034 |
+
|
| 1035 |
+
# Add sanitized tool output to trace when defense is enabled
|
| 1036 |
sanitized_tool_message = {
|
| 1037 |
"role": "tool",
|
| 1038 |
"tool_call_id": tool_call_info['id'],
|
|
|
|
| 1178 |
final_trace_msg = f"📊 Trace push completed (with {len(all_annotations)} annotations)"
|
| 1179 |
execution_log.append(final_trace_msg)
|
| 1180 |
|
| 1181 |
+
# Combine all flagged content for UI display
|
| 1182 |
+
combined_flagged_content = "\n\n".join(all_flagged_content) if all_flagged_content else ""
|
| 1183 |
+
print(f"🔍 DEBUG: Final combined_flagged_content: '{combined_flagged_content}'")
|
| 1184 |
+
print(f"🔍 DEBUG: Length: {len(combined_flagged_content)} characters")
|
| 1185 |
+
|
| 1186 |
+
return "\n".join(execution_log), llm_output, combined_flagged_content
|
app.py
CHANGED
|
@@ -526,13 +526,6 @@ def is_likely_gibberish_soft(text):
|
|
| 526 |
|
| 527 |
return False # Passes soft gibberish checks
|
| 528 |
|
| 529 |
-
def validate_english_with_model_loading(text):
|
| 530 |
-
"""
|
| 531 |
-
Convenience function that handles FastText model loading automatically.
|
| 532 |
-
"""
|
| 533 |
-
model = load_fasttext_model() # This will download and load the model if needed
|
| 534 |
-
return validate_english_only_windowed(text, model)
|
| 535 |
-
|
| 536 |
def get_fasttext_confidence_scores(text, model=None, top_k=3):
|
| 537 |
"""
|
| 538 |
Get top language confidence scores from FastText without doing validation.
|
|
@@ -761,7 +754,7 @@ def submit_attack(from_addr, attack_subject, attack_body, model_name="gpt-4o", d
|
|
| 761 |
}
|
| 762 |
|
| 763 |
# Process the fixed user query with the tool agent loop
|
| 764 |
-
execution_log, final_output = tool_agent_loop(
|
| 765 |
user_query=USER_INPUT,
|
| 766 |
inbox=INBOX,
|
| 767 |
system_prompt=SYSTEM_PROMPT,
|
|
@@ -771,13 +764,13 @@ def submit_attack(from_addr, attack_subject, attack_body, model_name="gpt-4o", d
|
|
| 771 |
fasttext_confidence_scores=fasttext_confidence_scores
|
| 772 |
)
|
| 773 |
|
| 774 |
-
# Return execution log
|
| 775 |
-
return execution_log, final_output
|
| 776 |
|
| 777 |
except Exception as e:
|
| 778 |
error_msg = f"❌ Error processing attack: {str(e)}"
|
| 779 |
print(error_msg)
|
| 780 |
-
return "", error_msg
|
| 781 |
|
| 782 |
def reset_to_initial_state():
|
| 783 |
"""Reset the inbox to original state and clear all inputs"""
|
|
@@ -1175,6 +1168,74 @@ def create_interface():
|
|
| 1175 |
.results-card ul { margin: 0; padding-left: 16px; }
|
| 1176 |
.results-card li { margin: 4px 0; }
|
| 1177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1178 |
/* Error Modal Popup Styling */
|
| 1179 |
.error-modal-overlay {
|
| 1180 |
position: fixed !important;
|
|
@@ -1512,6 +1573,14 @@ Satya
|
|
| 1512 |
)
|
| 1513 |
# Attack results summary (pretty list)
|
| 1514 |
results_display = gr.HTML("", elem_id="attack-results")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1515 |
with gr.Accordion("Show Execution Trace", open=False):
|
| 1516 |
trace_display = gr.Textbox(
|
| 1517 |
lines=14,
|
|
@@ -1705,7 +1774,9 @@ Satya
|
|
| 1705 |
gr.update(), # email1_display - no change
|
| 1706 |
gr.update(), # email2_display - no change
|
| 1707 |
gr.update(), # email3_display - no change
|
| 1708 |
-
gr.update(value=modal_html, visible=True) # error_modal_html
|
|
|
|
|
|
|
| 1709 |
)
|
| 1710 |
|
| 1711 |
print("✅ ALL VALIDATION PASSED - proceeding with attack submission")
|
|
@@ -1717,7 +1788,7 @@ Satya
|
|
| 1717 |
}
|
| 1718 |
|
| 1719 |
try:
|
| 1720 |
-
exec_log, final_out = submit_attack(from_addr.strip(), subject, body, model, defense_enabled, user_info.strip(), confidence_scores)
|
| 1721 |
except Exception as e:
|
| 1722 |
# Handle any setup or execution errors with detailed messages
|
| 1723 |
error_str = str(e).lower()
|
|
@@ -1773,7 +1844,9 @@ Satya
|
|
| 1773 |
gr.update(), # email1_display - no change
|
| 1774 |
gr.update(), # email2_display - no change
|
| 1775 |
gr.update(), # email3_display - no change
|
| 1776 |
-
gr.update(value=modal_html, visible=True) # error_modal_html
|
|
|
|
|
|
|
| 1777 |
)
|
| 1778 |
|
| 1779 |
# Build a formatted results summary extracted from exec_log
|
|
@@ -1818,16 +1891,34 @@ Satya
|
|
| 1818 |
for i, email in enumerate(emails_to_display):
|
| 1819 |
updated_emails.append(format_single_email(email, i + 1))
|
| 1820 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1821 |
# Return results with hidden error modal (validation passed)
|
| 1822 |
success_timestamp = int(time.time() * 1000)
|
| 1823 |
print(f"✅ Validation successful at {success_timestamp} - hiding error modal")
|
| 1824 |
return (final_out, results_html, exec_log, updated_emails[0], updated_emails[1], updated_emails[2],
|
| 1825 |
-
gr.update(value="", visible=False)
|
|
|
|
|
|
|
| 1826 |
|
| 1827 |
submit_btn.click(
|
| 1828 |
fn=submit_and_update,
|
| 1829 |
inputs=[attack_from, attack_subject, attack_body, model_selector, defense_toggle, user_info],
|
| 1830 |
-
outputs=[final_output_display, results_display, trace_display, email1_display, email2_display, email3_display, error_modal_html]
|
| 1831 |
)
|
| 1832 |
|
| 1833 |
# Connect dismiss trigger to properly hide the modal
|
|
|
|
| 526 |
|
| 527 |
return False # Passes soft gibberish checks
|
| 528 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
def get_fasttext_confidence_scores(text, model=None, top_k=3):
|
| 530 |
"""
|
| 531 |
Get top language confidence scores from FastText without doing validation.
|
|
|
|
| 754 |
}
|
| 755 |
|
| 756 |
# Process the fixed user query with the tool agent loop
|
| 757 |
+
execution_log, final_output, flagged_content = tool_agent_loop(
|
| 758 |
user_query=USER_INPUT,
|
| 759 |
inbox=INBOX,
|
| 760 |
system_prompt=SYSTEM_PROMPT,
|
|
|
|
| 764 |
fasttext_confidence_scores=fasttext_confidence_scores
|
| 765 |
)
|
| 766 |
|
| 767 |
+
# Return execution log, final output, and flagged content separately
|
| 768 |
+
return execution_log, final_output, flagged_content
|
| 769 |
|
| 770 |
except Exception as e:
|
| 771 |
error_msg = f"❌ Error processing attack: {str(e)}"
|
| 772 |
print(error_msg)
|
| 773 |
+
return "", error_msg, ""
|
| 774 |
|
| 775 |
def reset_to_initial_state():
|
| 776 |
"""Reset the inbox to original state and clear all inputs"""
|
|
|
|
| 1168 |
.results-card ul { margin: 0; padding-left: 16px; }
|
| 1169 |
.results-card li { margin: 4px 0; }
|
| 1170 |
|
| 1171 |
+
|
| 1172 |
+
|
| 1173 |
+
/* Accordion content styling for flagged content */
|
| 1174 |
+
.gr-accordion .gr-panel:has([data-testid="HTML"]) {
|
| 1175 |
+
max-height: 300px !important;
|
| 1176 |
+
overflow-y: auto !important;
|
| 1177 |
+
padding: 16px !important;
|
| 1178 |
+
background: white !important;
|
| 1179 |
+
border-radius: 8px !important;
|
| 1180 |
+
font-family: 'Roboto', sans-serif !important;
|
| 1181 |
+
line-height: 1.6 !important;
|
| 1182 |
+
color: #333333 !important;
|
| 1183 |
+
word-wrap: break-word !important;
|
| 1184 |
+
overflow-wrap: break-word !important;
|
| 1185 |
+
scrollbar-width: thin !important;
|
| 1186 |
+
}
|
| 1187 |
+
|
| 1188 |
+
/* Scrollbar styling for accordion content */
|
| 1189 |
+
.gr-accordion .gr-panel:has([data-testid="HTML"])::-webkit-scrollbar {
|
| 1190 |
+
width: 8px !important;
|
| 1191 |
+
}
|
| 1192 |
+
|
| 1193 |
+
.gr-accordion .gr-panel:has([data-testid="HTML"])::-webkit-scrollbar-track {
|
| 1194 |
+
background: rgba(0,0,0,0.1) !important;
|
| 1195 |
+
border-radius: 4px !important;
|
| 1196 |
+
}
|
| 1197 |
+
|
| 1198 |
+
.gr-accordion .gr-panel:has([data-testid="HTML"])::-webkit-scrollbar-thumb {
|
| 1199 |
+
background: rgba(0,0,0,0.3) !important;
|
| 1200 |
+
border-radius: 4px !important;
|
| 1201 |
+
}
|
| 1202 |
+
|
| 1203 |
+
.gr-accordion .gr-panel:has([data-testid="HTML"])::-webkit-scrollbar-thumb:hover {
|
| 1204 |
+
background: rgba(0,0,0,0.5) !important;
|
| 1205 |
+
}
|
| 1206 |
+
|
| 1207 |
+
/* Instruction tag styling for light mode */
|
| 1208 |
+
instruction {
|
| 1209 |
+
background-color: #ffebee !important;
|
| 1210 |
+
color: #c62828 !important;
|
| 1211 |
+
padding: 2px 6px !important;
|
| 1212 |
+
border-radius: 4px !important;
|
| 1213 |
+
font-weight: 600 !important;
|
| 1214 |
+
border: 1px solid #ef5350 !important;
|
| 1215 |
+
box-shadow: 0 1px 2px rgba(198, 40, 40, 0.2) !important;
|
| 1216 |
+
display: inline !important;
|
| 1217 |
+
font-family: 'Roboto', sans-serif !important;
|
| 1218 |
+
font-size: 14px !important;
|
| 1219 |
+
line-height: 1.4 !important;
|
| 1220 |
+
margin: 0 2px !important;
|
| 1221 |
+
}
|
| 1222 |
+
|
| 1223 |
+
/* Instruction tag styling for dark mode */
|
| 1224 |
+
@media (prefers-color-scheme: dark) {
|
| 1225 |
+
instruction {
|
| 1226 |
+
background-color: rgb(84 37 37) !important;
|
| 1227 |
+
color: #ffffff !important;
|
| 1228 |
+
border: 1px solid #d32f2f !important;
|
| 1229 |
+
box-shadow: 0 1px 3px rgba(183, 28, 28, 0.4) !important;
|
| 1230 |
+
}
|
| 1231 |
+
|
| 1232 |
+
/* Also ensure accordion content has proper dark mode styling */
|
| 1233 |
+
.gr-accordion .gr-panel:has([data-testid="HTML"]) {
|
| 1234 |
+
background: var(--background-fill-primary) !important;
|
| 1235 |
+
color: var(--body-text-color) !important;
|
| 1236 |
+
}
|
| 1237 |
+
}
|
| 1238 |
+
|
| 1239 |
/* Error Modal Popup Styling */
|
| 1240 |
.error-modal-overlay {
|
| 1241 |
position: fixed !important;
|
|
|
|
| 1573 |
)
|
| 1574 |
# Attack results summary (pretty list)
|
| 1575 |
results_display = gr.HTML("", elem_id="attack-results")
|
| 1576 |
+
|
| 1577 |
+
# Flagged content display (only shown when defense enabled and content found)
|
| 1578 |
+
with gr.Accordion("Show What was Flagged", open=False, visible=False) as flagged_accordion:
|
| 1579 |
+
flagged_content_display = gr.HTML(
|
| 1580 |
+
"",
|
| 1581 |
+
show_label=False
|
| 1582 |
+
)
|
| 1583 |
+
|
| 1584 |
with gr.Accordion("Show Execution Trace", open=False):
|
| 1585 |
trace_display = gr.Textbox(
|
| 1586 |
lines=14,
|
|
|
|
| 1774 |
gr.update(), # email1_display - no change
|
| 1775 |
gr.update(), # email2_display - no change
|
| 1776 |
gr.update(), # email3_display - no change
|
| 1777 |
+
gr.update(value=modal_html, visible=True), # error_modal_html
|
| 1778 |
+
gr.update(), # flagged_accordion - no change
|
| 1779 |
+
gr.update() # flagged_content_display - no change
|
| 1780 |
)
|
| 1781 |
|
| 1782 |
print("✅ ALL VALIDATION PASSED - proceeding with attack submission")
|
|
|
|
| 1788 |
}
|
| 1789 |
|
| 1790 |
try:
|
| 1791 |
+
exec_log, final_out, flagged_content = submit_attack(from_addr.strip(), subject, body, model, defense_enabled, user_info.strip(), confidence_scores)
|
| 1792 |
except Exception as e:
|
| 1793 |
# Handle any setup or execution errors with detailed messages
|
| 1794 |
error_str = str(e).lower()
|
|
|
|
| 1844 |
gr.update(), # email1_display - no change
|
| 1845 |
gr.update(), # email2_display - no change
|
| 1846 |
gr.update(), # email3_display - no change
|
| 1847 |
+
gr.update(value=modal_html, visible=True), # error_modal_html
|
| 1848 |
+
gr.update(), # flagged_accordion - no change
|
| 1849 |
+
gr.update() # flagged_content_display - no change
|
| 1850 |
)
|
| 1851 |
|
| 1852 |
# Build a formatted results summary extracted from exec_log
|
|
|
|
| 1891 |
for i, email in enumerate(emails_to_display):
|
| 1892 |
updated_emails.append(format_single_email(email, i + 1))
|
| 1893 |
|
| 1894 |
+
# Process flagged content for display
|
| 1895 |
+
flagged_display_html = ""
|
| 1896 |
+
flagged_accordion_visible = False
|
| 1897 |
+
flagged_accordion_open = False
|
| 1898 |
+
|
| 1899 |
+
if defense_enabled and flagged_content and flagged_content.strip():
|
| 1900 |
+
# Convert newlines to HTML line breaks for proper rendering
|
| 1901 |
+
flagged_content_html = flagged_content.replace('\n', '<br>')
|
| 1902 |
+
# Simple HTML structure without extra containers
|
| 1903 |
+
flagged_display_html = flagged_content_html
|
| 1904 |
+
flagged_accordion_visible = True
|
| 1905 |
+
flagged_accordion_open = True # Open after submit when there's content
|
| 1906 |
+
print(f"🏷️ Flagged content prepared for UI: {len(flagged_content)} characters")
|
| 1907 |
+
else:
|
| 1908 |
+
print("🏷️ No flagged content to display")
|
| 1909 |
+
|
| 1910 |
# Return results with hidden error modal (validation passed)
|
| 1911 |
success_timestamp = int(time.time() * 1000)
|
| 1912 |
print(f"✅ Validation successful at {success_timestamp} - hiding error modal")
|
| 1913 |
return (final_out, results_html, exec_log, updated_emails[0], updated_emails[1], updated_emails[2],
|
| 1914 |
+
gr.update(value="", visible=False), # Hide error modal
|
| 1915 |
+
gr.update(visible=flagged_accordion_visible, open=flagged_accordion_open), # Update flagged accordion
|
| 1916 |
+
gr.update(value=flagged_display_html)) # Update flagged content
|
| 1917 |
|
| 1918 |
submit_btn.click(
|
| 1919 |
fn=submit_and_update,
|
| 1920 |
inputs=[attack_from, attack_subject, attack_body, model_selector, defense_toggle, user_info],
|
| 1921 |
+
outputs=[final_output_display, results_display, trace_display, email1_display, email2_display, email3_display, error_modal_html, flagged_accordion, flagged_content_display]
|
| 1922 |
)
|
| 1923 |
|
| 1924 |
# Connect dismiss trigger to properly hide the modal
|
instruction_classifier.py
CHANGED
|
@@ -186,7 +186,7 @@ class InstructionClassifierSanitizer:
|
|
| 186 |
|
| 187 |
|
| 188 |
@spaces.GPU
|
| 189 |
-
def sanitize_with_annotations(self, tool_output: str) -> Tuple[str, List[Dict[str, any]]]:
|
| 190 |
"""
|
| 191 |
Sanitization function that also returns annotation data for flagged content.
|
| 192 |
|
|
@@ -194,11 +194,13 @@ class InstructionClassifierSanitizer:
|
|
| 194 |
tool_output: The raw tool output string
|
| 195 |
|
| 196 |
Returns:
|
| 197 |
-
Tuple of (sanitized_output, annotations) where
|
| 198 |
-
|
|
|
|
|
|
|
| 199 |
"""
|
| 200 |
if not tool_output or not tool_output.strip():
|
| 201 |
-
return tool_output, []
|
| 202 |
|
| 203 |
# Move model to target device (GPU) within @spaces.GPU decorated method
|
| 204 |
if self.device != self.target_device:
|
|
@@ -214,7 +216,7 @@ class InstructionClassifierSanitizer:
|
|
| 214 |
|
| 215 |
if not is_injection:
|
| 216 |
print("✅ No injection detected - returning original output")
|
| 217 |
-
return tool_output, []
|
| 218 |
|
| 219 |
print(f"🚨 Injection detected! Processing with extensions and annotations...")
|
| 220 |
|
|
@@ -233,12 +235,13 @@ class InstructionClassifierSanitizer:
|
|
| 233 |
# Step 5: Remove instruction tags and their content
|
| 234 |
sanitized_output = self._remove_instruction_tags(merged_tagged_text)
|
| 235 |
print(f"🔒 Sanitized output: {sanitized_output}")
|
| 236 |
-
|
|
|
|
| 237 |
|
| 238 |
except Exception as e:
|
| 239 |
print(f"❌ Error in instruction classifier sanitization: {e}")
|
| 240 |
# Return original output if sanitization fails
|
| 241 |
-
return tool_output, []
|
| 242 |
|
| 243 |
def _extract_annotations_from_tagged_text(self, tagged_text: str, original_text: str) -> List[Dict[str, any]]:
|
| 244 |
"""
|
|
@@ -700,23 +703,25 @@ def sanitize_tool_output_with_annotations(tool_output, defense_enabled=True):
|
|
| 700 |
defense_enabled: Whether defense is enabled (passed from agent)
|
| 701 |
|
| 702 |
Returns:
|
| 703 |
-
Tuple of (sanitized_output, annotations) where
|
| 704 |
-
|
|
|
|
|
|
|
| 705 |
"""
|
| 706 |
print(f"🔍 sanitize_tool_output_with_annotations called with: {tool_output[:100]}...")
|
| 707 |
|
| 708 |
# If defense is disabled globally, return original output with no annotations
|
| 709 |
if not defense_enabled:
|
| 710 |
print("⚠️ Defense disabled - returning original output without processing")
|
| 711 |
-
return tool_output, []
|
| 712 |
|
| 713 |
sanitizer = get_sanitizer()
|
| 714 |
if sanitizer is None:
|
| 715 |
print("⚠️ Instruction classifier not available, returning original output")
|
| 716 |
-
return tool_output, []
|
| 717 |
|
| 718 |
print("✅ Sanitizer found, processing with annotations...")
|
| 719 |
-
sanitized_output, annotations = sanitizer.sanitize_with_annotations(tool_output)
|
| 720 |
print(f"🔒 Sanitization complete, result: {sanitized_output[:100]}...")
|
| 721 |
print(f"📝 Found {len(annotations)} annotations")
|
| 722 |
-
return sanitized_output, annotations
|
|
|
|
| 186 |
|
| 187 |
|
| 188 |
@spaces.GPU
|
| 189 |
+
def sanitize_with_annotations(self, tool_output: str) -> Tuple[str, List[Dict[str, any]], str]:
|
| 190 |
"""
|
| 191 |
Sanitization function that also returns annotation data for flagged content.
|
| 192 |
|
|
|
|
| 194 |
tool_output: The raw tool output string
|
| 195 |
|
| 196 |
Returns:
|
| 197 |
+
Tuple of (sanitized_output, annotations, merged_tagged_text) where:
|
| 198 |
+
- sanitized_output: cleaned text with instruction content removed
|
| 199 |
+
- annotations: position information for content flagged by classifier
|
| 200 |
+
- merged_tagged_text: text with <instruction> tags showing detected content
|
| 201 |
"""
|
| 202 |
if not tool_output or not tool_output.strip():
|
| 203 |
+
return tool_output, [], tool_output
|
| 204 |
|
| 205 |
# Move model to target device (GPU) within @spaces.GPU decorated method
|
| 206 |
if self.device != self.target_device:
|
|
|
|
| 216 |
|
| 217 |
if not is_injection:
|
| 218 |
print("✅ No injection detected - returning original output")
|
| 219 |
+
return tool_output, [], tool_output
|
| 220 |
|
| 221 |
print(f"🚨 Injection detected! Processing with extensions and annotations...")
|
| 222 |
|
|
|
|
| 235 |
# Step 5: Remove instruction tags and their content
|
| 236 |
sanitized_output = self._remove_instruction_tags(merged_tagged_text)
|
| 237 |
print(f"🔒 Sanitized output: {sanitized_output}")
|
| 238 |
+
print(f"🔍 DEBUG SANITIZER: Returning merged_tagged_text: '{merged_tagged_text}'")
|
| 239 |
+
return sanitized_output, annotations, merged_tagged_text
|
| 240 |
|
| 241 |
except Exception as e:
|
| 242 |
print(f"❌ Error in instruction classifier sanitization: {e}")
|
| 243 |
# Return original output if sanitization fails
|
| 244 |
+
return tool_output, [], tool_output
|
| 245 |
|
| 246 |
def _extract_annotations_from_tagged_text(self, tagged_text: str, original_text: str) -> List[Dict[str, any]]:
|
| 247 |
"""
|
|
|
|
| 703 |
defense_enabled: Whether defense is enabled (passed from agent)
|
| 704 |
|
| 705 |
Returns:
|
| 706 |
+
Tuple of (sanitized_output, annotations, merged_tagged_text) where:
|
| 707 |
+
- sanitized_output: cleaned text with instruction content removed
|
| 708 |
+
- annotations: position information for content flagged by classifier
|
| 709 |
+
- merged_tagged_text: text with <instruction> tags showing detected content
|
| 710 |
"""
|
| 711 |
print(f"🔍 sanitize_tool_output_with_annotations called with: {tool_output[:100]}...")
|
| 712 |
|
| 713 |
# If defense is disabled globally, return original output with no annotations
|
| 714 |
if not defense_enabled:
|
| 715 |
print("⚠️ Defense disabled - returning original output without processing")
|
| 716 |
+
return tool_output, [], tool_output
|
| 717 |
|
| 718 |
sanitizer = get_sanitizer()
|
| 719 |
if sanitizer is None:
|
| 720 |
print("⚠️ Instruction classifier not available, returning original output")
|
| 721 |
+
return tool_output, [], tool_output
|
| 722 |
|
| 723 |
print("✅ Sanitizer found, processing with annotations...")
|
| 724 |
+
sanitized_output, annotations, merged_tagged_text = sanitizer.sanitize_with_annotations(tool_output)
|
| 725 |
print(f"🔒 Sanitization complete, result: {sanitized_output[:100]}...")
|
| 726 |
print(f"📝 Found {len(annotations)} annotations")
|
| 727 |
+
return sanitized_output, annotations, merged_tagged_text
|
utils.py
CHANGED
|
@@ -1,147 +1,19 @@
|
|
| 1 |
import json
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
-
from torch.utils.data import Dataset
|
| 5 |
-
from transformers import
|
| 6 |
-
import numpy as np
|
| 7 |
-
from tqdm import tqdm
|
| 8 |
import re
|
| 9 |
-
from typing import List,
|
| 10 |
import warnings
|
| 11 |
import logging
|
| 12 |
import os
|
| 13 |
-
from datetime import datetime
|
| 14 |
-
from sklearn.utils.class_weight import compute_class_weight
|
| 15 |
-
import torch.nn.functional as F
|
| 16 |
|
| 17 |
# Disable tokenizer parallelism to avoid forking warnings
|
| 18 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 19 |
|
| 20 |
warnings.filterwarnings('ignore')
|
| 21 |
|
| 22 |
-
def set_random_seeds(seed=42):
|
| 23 |
-
"""Set random seeds for reproducibility"""
|
| 24 |
-
import random
|
| 25 |
-
import numpy as np
|
| 26 |
-
import torch
|
| 27 |
-
|
| 28 |
-
random.seed(seed)
|
| 29 |
-
np.random.seed(seed)
|
| 30 |
-
torch.manual_seed(seed)
|
| 31 |
-
torch.cuda.manual_seed_all(seed) # For multi-GPU
|
| 32 |
-
|
| 33 |
-
# Make CuDNN deterministic (slower but reproducible)
|
| 34 |
-
torch.backends.cudnn.deterministic = True
|
| 35 |
-
torch.backends.cudnn.benchmark = False
|
| 36 |
-
|
| 37 |
-
def setup_logging(log_dir='data/logs'):
|
| 38 |
-
"""Setup logging configuration"""
|
| 39 |
-
# Create logs directory if it doesn't exist
|
| 40 |
-
os.makedirs(log_dir, exist_ok=True)
|
| 41 |
-
|
| 42 |
-
# Create timestamp for log file
|
| 43 |
-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 44 |
-
log_file = os.path.join(log_dir, f'training_log_{timestamp}.log')
|
| 45 |
-
|
| 46 |
-
# Configure logging
|
| 47 |
-
logging.basicConfig(
|
| 48 |
-
level=logging.INFO, # Back to INFO level
|
| 49 |
-
format='%(asctime)s - %(levelname)s - %(message)s',
|
| 50 |
-
handlers=[
|
| 51 |
-
logging.FileHandler(log_file),
|
| 52 |
-
logging.StreamHandler() # Also print to console
|
| 53 |
-
]
|
| 54 |
-
)
|
| 55 |
-
|
| 56 |
-
logger = logging.getLogger(__name__)
|
| 57 |
-
logger.info(f"Logging initialized. Log file: {log_file}")
|
| 58 |
-
return logger, log_file
|
| 59 |
-
|
| 60 |
-
def check_gpu_availability():
|
| 61 |
-
"""Check and print GPU availability information"""
|
| 62 |
-
logger = logging.getLogger(__name__)
|
| 63 |
-
logger.info("=== GPU Availability Check ===")
|
| 64 |
-
|
| 65 |
-
if torch.backends.mps.is_available():
|
| 66 |
-
logger.info("✓ MPS (Apple Silicon GPU) is available")
|
| 67 |
-
if torch.backends.mps.is_built():
|
| 68 |
-
logger.info("✓ MPS is built into PyTorch")
|
| 69 |
-
else:
|
| 70 |
-
logger.info("✗ MPS is not built into PyTorch")
|
| 71 |
-
else:
|
| 72 |
-
logger.info("✗ MPS (Apple Silicon GPU) is not available")
|
| 73 |
-
|
| 74 |
-
if torch.cuda.is_available():
|
| 75 |
-
logger.info(f"✓ CUDA is available (GPU count: {torch.cuda.device_count()})")
|
| 76 |
-
else:
|
| 77 |
-
logger.info("✗ CUDA is not available")
|
| 78 |
-
|
| 79 |
-
logger.info(f"PyTorch version: {torch.__version__}")
|
| 80 |
-
logger.info("=" * 50)
|
| 81 |
-
|
| 82 |
-
def calculate_class_weights(dataset):
|
| 83 |
-
"""Calculate class weights for imbalanced dataset using BERT paper approach"""
|
| 84 |
-
logger = logging.getLogger(__name__)
|
| 85 |
-
|
| 86 |
-
# Collect all labels from the dataset (BERT approach: only first subtokens have real labels)
|
| 87 |
-
all_labels = []
|
| 88 |
-
for window_data in dataset.processed_data:
|
| 89 |
-
# Filter out -100 labels (special tokens + subsequent subtokens of same word)
|
| 90 |
-
# This gives us true word-level class distribution
|
| 91 |
-
valid_labels = [label for label in window_data['subword_labels'] if label != -100]
|
| 92 |
-
all_labels.extend(valid_labels)
|
| 93 |
-
|
| 94 |
-
# Convert to numpy array
|
| 95 |
-
y = np.array(all_labels)
|
| 96 |
-
|
| 97 |
-
# Calculate class weights using sklearn
|
| 98 |
-
classes = np.unique(y)
|
| 99 |
-
class_weights = compute_class_weight('balanced', classes=classes, y=y)
|
| 100 |
-
|
| 101 |
-
# Create weight tensor
|
| 102 |
-
weight_tensor = torch.FloatTensor(class_weights)
|
| 103 |
-
|
| 104 |
-
logger.info(f"Word-level class distribution: {np.bincount(y)}")
|
| 105 |
-
logger.info(f"Class 0 (Non-instruction words): {np.sum(y == 0)} words ({np.sum(y == 0)/len(y)*100:.1f}%)")
|
| 106 |
-
logger.info(f"Class 1 (Instruction words): {np.sum(y == 1)} words ({np.sum(y == 1)/len(y)*100:.1f}%)")
|
| 107 |
-
logger.info(f"Calculated class weights (word-level): {class_weights}")
|
| 108 |
-
logger.info(f" Weight for class 0 (Non-instruction): {class_weights[0]:.4f}")
|
| 109 |
-
logger.info(f" Weight for class 1 (Instruction): {class_weights[1]:.4f}")
|
| 110 |
-
|
| 111 |
-
return weight_tensor
|
| 112 |
-
|
| 113 |
-
class FocalLoss(nn.Module):
|
| 114 |
-
"""Focal Loss for addressing class imbalance"""
|
| 115 |
-
def __init__(self, alpha=1, gamma=2, ignore_index=-100):
|
| 116 |
-
super(FocalLoss, self).__init__()
|
| 117 |
-
self.alpha = alpha
|
| 118 |
-
self.gamma = gamma
|
| 119 |
-
self.ignore_index = ignore_index
|
| 120 |
-
|
| 121 |
-
def forward(self, inputs, targets):
|
| 122 |
-
# Flatten inputs and targets
|
| 123 |
-
inputs = inputs.view(-1, inputs.size(-1))
|
| 124 |
-
targets = targets.view(-1)
|
| 125 |
-
|
| 126 |
-
# Create mask for non-ignored indices
|
| 127 |
-
mask = targets != self.ignore_index
|
| 128 |
-
targets = targets[mask]
|
| 129 |
-
inputs = inputs[mask]
|
| 130 |
-
|
| 131 |
-
if len(targets) == 0:
|
| 132 |
-
return torch.tensor(0.0, requires_grad=True, device=inputs.device)
|
| 133 |
-
|
| 134 |
-
# Calculate cross entropy
|
| 135 |
-
ce_loss = F.cross_entropy(inputs, targets, reduction='none')
|
| 136 |
-
|
| 137 |
-
# Calculate pt
|
| 138 |
-
pt = torch.exp(-ce_loss)
|
| 139 |
-
|
| 140 |
-
# Calculate focal loss
|
| 141 |
-
focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
|
| 142 |
-
|
| 143 |
-
return focal_loss.mean()
|
| 144 |
-
|
| 145 |
class InstructionDataset(Dataset):
|
| 146 |
def __init__(self, data_path: str, tokenizer, max_length: int = 512, is_training: bool = True,
|
| 147 |
window_size: int = 512, overlap: int = 100):
|
|
@@ -517,8 +389,6 @@ class TransformerInstructionClassifier(nn.Module):
|
|
| 517 |
# Setup loss function based on type
|
| 518 |
if loss_type == 'weighted_ce':
|
| 519 |
self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, weight=class_weights)
|
| 520 |
-
elif loss_type == 'focal':
|
| 521 |
-
self.loss_fct = FocalLoss(alpha=1, gamma=2, ignore_index=-100)
|
| 522 |
else:
|
| 523 |
self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
| 524 |
|
|
|
|
| 1 |
import json
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
from transformers import AutoModel
|
|
|
|
|
|
|
| 6 |
import re
|
| 7 |
+
from typing import List, Dict, Any
|
| 8 |
import warnings
|
| 9 |
import logging
|
| 10 |
import os
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
# Disable tokenizer parallelism to avoid forking warnings
|
| 13 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 14 |
|
| 15 |
warnings.filterwarnings('ignore')
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
class InstructionDataset(Dataset):
|
| 18 |
def __init__(self, data_path: str, tokenizer, max_length: int = 512, is_training: bool = True,
|
| 19 |
window_size: int = 512, overlap: int = 100):
|
|
|
|
| 389 |
# Setup loss function based on type
|
| 390 |
if loss_type == 'weighted_ce':
|
| 391 |
self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, weight=class_weights)
|
|
|
|
|
|
|
| 392 |
else:
|
| 393 |
self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
| 394 |
|