Spaces:
Running
on
Zero
Running
on
Zero
UI tagged preview added
Browse files- agent.py +67 -28
- app.py +107 -16
- instruction_classifier.py +18 -13
- utils.py +3 -133
agent.py
CHANGED
@@ -551,26 +551,6 @@ Body: {email.body_value}"""
|
|
551 |
# Import the instruction classifier sanitizer
|
552 |
from instruction_classifier import sanitize_tool_output_with_annotations
|
553 |
|
554 |
-
|
555 |
-
def extract_tool_calls(text):
    """Extract tool calls from LLM output (legacy function - kept for compatibility).

    Scans ``text`` for the textual tool-call forms listed in ``patterns`` and
    returns the matched substrings in the order they appear in ``text``.

    Args:
        text (str): Raw LLM output to scan.

    Returns:
        list[str]: The matched call strings (for example ``'get_emails()'``),
        ordered by their position in ``text``; empty when nothing matches.
    """
    # Patterns to match tool calls
    patterns = [
        r'get_emails\(\)',
        r'search_email\(keyword=[^)]*\)',  # search_email(keyword="UBS")
        r'search_email\(\s*"[^"]+"\s*\)',  # search_email("UBS")
        r'send_email\([^)]+\)',
    ]

    # Each pattern is still scanned independently (so the set of matches is
    # unchanged), but the start offset is recorded so the result can be put
    # back into text order instead of being grouped by pattern.
    found = []
    for pattern_index, pattern in enumerate(patterns):
        for match in re.finditer(pattern, text):
            found.append((match.start(), pattern_index, match.group(0)))

    # Sort by offset; the pattern index breaks ties deterministically.
    found.sort()
    return [call for _, _, call in found]
|
572 |
-
|
573 |
-
|
574 |
def extract_and_parse_tool_calls(text):
|
575 |
"""
|
576 |
Extract tool calls from LLM output and parse them into structured format
|
@@ -714,6 +694,38 @@ def create_assistant_message_with_tool_calls(llm_output, parsed_tool_calls, prov
|
|
714 |
return {"role": "assistant", "content": llm_output}
|
715 |
|
716 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
717 |
def create_tool_result_message(tool_results, provider):
|
718 |
"""
|
719 |
Create properly formatted tool result message based on LLM provider
|
@@ -843,6 +855,9 @@ def tool_agent_loop(user_query, inbox, system_prompt, model_name="gpt-4o-mini",
|
|
843 |
# Track annotations for instruction classifier flagged content
|
844 |
all_annotations = []
|
845 |
|
|
|
|
|
|
|
846 |
# Initialize conversation with system prompt and user query
|
847 |
# This will be used for LLM API calls (provider-specific format)
|
848 |
llm_messages = [
|
@@ -968,7 +983,17 @@ def tool_agent_loop(user_query, inbox, system_prompt, model_name="gpt-4o-mini",
|
|
968 |
# Conditional sanitization based on defense setting
|
969 |
if defense_enabled:
|
970 |
# Sanitize tool output with annotations
|
971 |
-
sanitized_output, annotations = sanitize_tool_output_with_annotations(tool_output, defense_enabled)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
972 |
|
973 |
# Always add raw tool output to trace when defense is enabled
|
974 |
raw_tool_message = {
|
@@ -988,17 +1013,26 @@ def tool_agent_loop(user_query, inbox, system_prompt, model_name="gpt-4o-mini",
|
|
988 |
|
989 |
# Add annotations to our collection
|
990 |
all_annotations.extend(annotations)
|
991 |
-
|
992 |
-
|
993 |
-
|
994 |
-
|
|
|
|
|
995 |
|
996 |
-
#
|
997 |
sanitized_msg = f"🔒 Sanitized Result: {sanitized_output}"
|
998 |
execution_log.append(sanitized_msg)
|
999 |
print(sanitized_msg)
|
1000 |
|
1001 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1002 |
sanitized_tool_message = {
|
1003 |
"role": "tool",
|
1004 |
"tool_call_id": tool_call_info['id'],
|
@@ -1144,4 +1178,9 @@ def tool_agent_loop(user_query, inbox, system_prompt, model_name="gpt-4o-mini",
|
|
1144 |
final_trace_msg = f"📊 Trace push completed (with {len(all_annotations)} annotations)"
|
1145 |
execution_log.append(final_trace_msg)
|
1146 |
|
1147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
551 |
# Import the instruction classifier sanitizer
|
552 |
from instruction_classifier import sanitize_tool_output_with_annotations
|
553 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
554 |
def extract_and_parse_tool_calls(text):
|
555 |
"""
|
556 |
Extract tool calls from LLM output and parse them into structured format
|
|
|
694 |
return {"role": "assistant", "content": llm_output}
|
695 |
|
696 |
|
697 |
+
def process_flagged_content(merged_tagged_text, target_phrase="To: [email protected]"):
    """
    Trim tagged classifier text down to what follows the second recipient header.

    The text is searched for ``target_phrase``. When the phrase occurs at least
    twice, everything up to and including the second occurrence is dropped, the
    remainder is stripped, and a newline is inserted before every "Time:" and
    "Body:" marker that is not already at the start of a line. When there is no
    second occurrence the input is returned unchanged.

    Args:
        merged_tagged_text (str): Text with <instruction> tags from classifier
        target_phrase (str): Header text whose second occurrence marks the cut
            point. Expected to be non-empty.

    Returns:
        str: Processed text with content before second occurrence removed, the
        original text when the phrase occurs fewer than two times, or "" for
        empty/None input.
    """
    if not merged_tagged_text:
        return ""

    # NOTE(review): the default address reads like a redaction placeholder -
    # confirm it matches the header actually present in the tool output.
    first_pos = merged_tagged_text.find(target_phrase)
    second_pos = -1
    if first_pos != -1:
        second_pos = merged_tagged_text.find(target_phrase, first_pos + 1)

    if second_pos == -1:
        # If no second occurrence, return entire text
        print(f"🏷️ No second occurrence found, returning entire flagged content: {merged_tagged_text[:100]}...")
        return merged_tagged_text

    # Remove everything before and including the second occurrence
    processed_text = merged_tagged_text[second_pos + len(target_phrase):].strip()
    print(f"🏷️ Found second occurrence at position {second_pos}, processed flagged content: {processed_text[:100]}...")

    # Insert newline before "Time:" and "Body:" unless the marker already starts
    # a line. The positive lookbehind requires a preceding non-newline
    # character, so a marker at the very start of the text is left alone
    # (a negative lookbehind for "\n" would also fire at offset 0).
    processed_text = re.sub(r'(?<=[^\n])(Time:|Body:)', r'\n\1', processed_text)
    return processed_text
|
727 |
+
|
728 |
+
|
729 |
def create_tool_result_message(tool_results, provider):
|
730 |
"""
|
731 |
Create properly formatted tool result message based on LLM provider
|
|
|
855 |
# Track annotations for instruction classifier flagged content
|
856 |
all_annotations = []
|
857 |
|
858 |
+
# Track flagged content for UI display
|
859 |
+
all_flagged_content = []
|
860 |
+
|
861 |
# Initialize conversation with system prompt and user query
|
862 |
# This will be used for LLM API calls (provider-specific format)
|
863 |
llm_messages = [
|
|
|
983 |
# Conditional sanitization based on defense setting
|
984 |
if defense_enabled:
|
985 |
# Sanitize tool output with annotations
|
986 |
+
sanitized_output, annotations, merged_tagged_text = sanitize_tool_output_with_annotations(tool_output, defense_enabled)
|
987 |
+
|
988 |
+
# Process and collect flagged content for UI display
|
989 |
+
print(f"🔍 DEBUG: merged_tagged_text: {merged_tagged_text}")
|
990 |
+
print(f"🔍 DEBUG: has <instruction> tags: {'<instruction>' in merged_tagged_text if merged_tagged_text else 'No text'}")
|
991 |
+
if merged_tagged_text and merged_tagged_text.strip() and "<instruction>" in merged_tagged_text:
|
992 |
+
processed_flagged = process_flagged_content(merged_tagged_text)
|
993 |
+
print(f"🔍 DEBUG: processed_flagged result: {processed_flagged}")
|
994 |
+
if processed_flagged:
|
995 |
+
all_flagged_content.append(processed_flagged)
|
996 |
+
print(f"🔍 DEBUG: Added to all_flagged_content. Total items: {len(all_flagged_content)}")
|
997 |
|
998 |
# Always add raw tool output to trace when defense is enabled
|
999 |
raw_tool_message = {
|
|
|
1013 |
|
1014 |
# Add annotations to our collection
|
1015 |
all_annotations.extend(annotations)
|
1016 |
+
|
1017 |
+
|
1018 |
+
# Add some spacing before sanitized output for clarity
|
1019 |
+
execution_log.append("")
|
1020 |
+
execution_log.append("--- DEFENSE PROCESSING ---")
|
1021 |
+
execution_log.append("")
|
1022 |
|
1023 |
+
# Show sanitized result in logs when defense is enabled
|
1024 |
sanitized_msg = f"🔒 Sanitized Result: {sanitized_output}"
|
1025 |
execution_log.append(sanitized_msg)
|
1026 |
print(sanitized_msg)
|
1027 |
|
1028 |
+
# Add spacing separator in trace for clarity
|
1029 |
+
separator_message = {
|
1030 |
+
"role": "system",
|
1031 |
+
"content": "--- DEFENSE SANITIZATION APPLIED ---"
|
1032 |
+
}
|
1033 |
+
trace_messages.append(separator_message)
|
1034 |
+
|
1035 |
+
# Add sanitized tool output to trace when defense is enabled
|
1036 |
sanitized_tool_message = {
|
1037 |
"role": "tool",
|
1038 |
"tool_call_id": tool_call_info['id'],
|
|
|
1178 |
final_trace_msg = f"📊 Trace push completed (with {len(all_annotations)} annotations)"
|
1179 |
execution_log.append(final_trace_msg)
|
1180 |
|
1181 |
+
# Combine all flagged content for UI display
|
1182 |
+
combined_flagged_content = "\n\n".join(all_flagged_content) if all_flagged_content else ""
|
1183 |
+
print(f"🔍 DEBUG: Final combined_flagged_content: '{combined_flagged_content}'")
|
1184 |
+
print(f"🔍 DEBUG: Length: {len(combined_flagged_content)} characters")
|
1185 |
+
|
1186 |
+
return "\n".join(execution_log), llm_output, combined_flagged_content
|
app.py
CHANGED
@@ -526,13 +526,6 @@ def is_likely_gibberish_soft(text):
|
|
526 |
|
527 |
return False # Passes soft gibberish checks
|
528 |
|
529 |
-
def validate_english_with_model_loading(text):
    """
    Run the windowed English validation on ``text`` without requiring a model.

    The FastText model is obtained from ``load_fasttext_model()`` and handed,
    together with ``text``, to ``validate_english_only_windowed``; that call's
    result is returned as-is.
    """
    # load_fasttext_model() downloads and loads the model if needed
    fasttext_model = load_fasttext_model()
    return validate_english_only_windowed(text, fasttext_model)
|
535 |
-
|
536 |
def get_fasttext_confidence_scores(text, model=None, top_k=3):
|
537 |
"""
|
538 |
Get top language confidence scores from FastText without doing validation.
|
@@ -761,7 +754,7 @@ def submit_attack(from_addr, attack_subject, attack_body, model_name="gpt-4o", d
|
|
761 |
}
|
762 |
|
763 |
# Process the fixed user query with the tool agent loop
|
764 |
-
execution_log, final_output = tool_agent_loop(
|
765 |
user_query=USER_INPUT,
|
766 |
inbox=INBOX,
|
767 |
system_prompt=SYSTEM_PROMPT,
|
@@ -771,13 +764,13 @@ def submit_attack(from_addr, attack_subject, attack_body, model_name="gpt-4o", d
|
|
771 |
fasttext_confidence_scores=fasttext_confidence_scores
|
772 |
)
|
773 |
|
774 |
-
# Return execution log
|
775 |
-
return execution_log, final_output
|
776 |
|
777 |
except Exception as e:
|
778 |
error_msg = f"❌ Error processing attack: {str(e)}"
|
779 |
print(error_msg)
|
780 |
-
return "", error_msg
|
781 |
|
782 |
def reset_to_initial_state():
|
783 |
"""Reset the inbox to original state and clear all inputs"""
|
@@ -1175,6 +1168,74 @@ def create_interface():
|
|
1175 |
.results-card ul { margin: 0; padding-left: 16px; }
|
1176 |
.results-card li { margin: 4px 0; }
|
1177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1178 |
/* Error Modal Popup Styling */
|
1179 |
.error-modal-overlay {
|
1180 |
position: fixed !important;
|
@@ -1512,6 +1573,14 @@ Satya
|
|
1512 |
)
|
1513 |
# Attack results summary (pretty list)
|
1514 |
results_display = gr.HTML("", elem_id="attack-results")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1515 |
with gr.Accordion("Show Execution Trace", open=False):
|
1516 |
trace_display = gr.Textbox(
|
1517 |
lines=14,
|
@@ -1705,7 +1774,9 @@ Satya
|
|
1705 |
gr.update(), # email1_display - no change
|
1706 |
gr.update(), # email2_display - no change
|
1707 |
gr.update(), # email3_display - no change
|
1708 |
-
gr.update(value=modal_html, visible=True) # error_modal_html
|
|
|
|
|
1709 |
)
|
1710 |
|
1711 |
print("✅ ALL VALIDATION PASSED - proceeding with attack submission")
|
@@ -1717,7 +1788,7 @@ Satya
|
|
1717 |
}
|
1718 |
|
1719 |
try:
|
1720 |
-
exec_log, final_out = submit_attack(from_addr.strip(), subject, body, model, defense_enabled, user_info.strip(), confidence_scores)
|
1721 |
except Exception as e:
|
1722 |
# Handle any setup or execution errors with detailed messages
|
1723 |
error_str = str(e).lower()
|
@@ -1773,7 +1844,9 @@ Satya
|
|
1773 |
gr.update(), # email1_display - no change
|
1774 |
gr.update(), # email2_display - no change
|
1775 |
gr.update(), # email3_display - no change
|
1776 |
-
gr.update(value=modal_html, visible=True) # error_modal_html
|
|
|
|
|
1777 |
)
|
1778 |
|
1779 |
# Build a formatted results summary extracted from exec_log
|
@@ -1818,16 +1891,34 @@ Satya
|
|
1818 |
for i, email in enumerate(emails_to_display):
|
1819 |
updated_emails.append(format_single_email(email, i + 1))
|
1820 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1821 |
# Return results with hidden error modal (validation passed)
|
1822 |
success_timestamp = int(time.time() * 1000)
|
1823 |
print(f"✅ Validation successful at {success_timestamp} - hiding error modal")
|
1824 |
return (final_out, results_html, exec_log, updated_emails[0], updated_emails[1], updated_emails[2],
|
1825 |
-
gr.update(value="", visible=False)
|
|
|
|
|
1826 |
|
1827 |
submit_btn.click(
|
1828 |
fn=submit_and_update,
|
1829 |
inputs=[attack_from, attack_subject, attack_body, model_selector, defense_toggle, user_info],
|
1830 |
-
outputs=[final_output_display, results_display, trace_display, email1_display, email2_display, email3_display, error_modal_html]
|
1831 |
)
|
1832 |
|
1833 |
# Connect dismiss trigger to properly hide the modal
|
|
|
526 |
|
527 |
return False # Passes soft gibberish checks
|
528 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
529 |
def get_fasttext_confidence_scores(text, model=None, top_k=3):
|
530 |
"""
|
531 |
Get top language confidence scores from FastText without doing validation.
|
|
|
754 |
}
|
755 |
|
756 |
# Process the fixed user query with the tool agent loop
|
757 |
+
execution_log, final_output, flagged_content = tool_agent_loop(
|
758 |
user_query=USER_INPUT,
|
759 |
inbox=INBOX,
|
760 |
system_prompt=SYSTEM_PROMPT,
|
|
|
764 |
fasttext_confidence_scores=fasttext_confidence_scores
|
765 |
)
|
766 |
|
767 |
+
# Return execution log, final output, and flagged content separately
|
768 |
+
return execution_log, final_output, flagged_content
|
769 |
|
770 |
except Exception as e:
|
771 |
error_msg = f"❌ Error processing attack: {str(e)}"
|
772 |
print(error_msg)
|
773 |
+
return "", error_msg, ""
|
774 |
|
775 |
def reset_to_initial_state():
|
776 |
"""Reset the inbox to original state and clear all inputs"""
|
|
|
1168 |
.results-card ul { margin: 0; padding-left: 16px; }
|
1169 |
.results-card li { margin: 4px 0; }
|
1170 |
|
1171 |
+
|
1172 |
+
|
1173 |
+
/* Accordion content styling for flagged content */
|
1174 |
+
.gr-accordion .gr-panel:has([data-testid="HTML"]) {
|
1175 |
+
max-height: 300px !important;
|
1176 |
+
overflow-y: auto !important;
|
1177 |
+
padding: 16px !important;
|
1178 |
+
background: white !important;
|
1179 |
+
border-radius: 8px !important;
|
1180 |
+
font-family: 'Roboto', sans-serif !important;
|
1181 |
+
line-height: 1.6 !important;
|
1182 |
+
color: #333333 !important;
|
1183 |
+
word-wrap: break-word !important;
|
1184 |
+
overflow-wrap: break-word !important;
|
1185 |
+
scrollbar-width: thin !important;
|
1186 |
+
}
|
1187 |
+
|
1188 |
+
/* Scrollbar styling for accordion content */
|
1189 |
+
.gr-accordion .gr-panel:has([data-testid="HTML"])::-webkit-scrollbar {
|
1190 |
+
width: 8px !important;
|
1191 |
+
}
|
1192 |
+
|
1193 |
+
.gr-accordion .gr-panel:has([data-testid="HTML"])::-webkit-scrollbar-track {
|
1194 |
+
background: rgba(0,0,0,0.1) !important;
|
1195 |
+
border-radius: 4px !important;
|
1196 |
+
}
|
1197 |
+
|
1198 |
+
.gr-accordion .gr-panel:has([data-testid="HTML"])::-webkit-scrollbar-thumb {
|
1199 |
+
background: rgba(0,0,0,0.3) !important;
|
1200 |
+
border-radius: 4px !important;
|
1201 |
+
}
|
1202 |
+
|
1203 |
+
.gr-accordion .gr-panel:has([data-testid="HTML"])::-webkit-scrollbar-thumb:hover {
|
1204 |
+
background: rgba(0,0,0,0.5) !important;
|
1205 |
+
}
|
1206 |
+
|
1207 |
+
/* Instruction tag styling for light mode */
|
1208 |
+
instruction {
|
1209 |
+
background-color: #ffebee !important;
|
1210 |
+
color: #c62828 !important;
|
1211 |
+
padding: 2px 6px !important;
|
1212 |
+
border-radius: 4px !important;
|
1213 |
+
font-weight: 600 !important;
|
1214 |
+
border: 1px solid #ef5350 !important;
|
1215 |
+
box-shadow: 0 1px 2px rgba(198, 40, 40, 0.2) !important;
|
1216 |
+
display: inline !important;
|
1217 |
+
font-family: 'Roboto', sans-serif !important;
|
1218 |
+
font-size: 14px !important;
|
1219 |
+
line-height: 1.4 !important;
|
1220 |
+
margin: 0 2px !important;
|
1221 |
+
}
|
1222 |
+
|
1223 |
+
/* Instruction tag styling for dark mode */
|
1224 |
+
@media (prefers-color-scheme: dark) {
|
1225 |
+
instruction {
|
1226 |
+
background-color: rgb(84 37 37) !important;
|
1227 |
+
color: #ffffff !important;
|
1228 |
+
border: 1px solid #d32f2f !important;
|
1229 |
+
box-shadow: 0 1px 3px rgba(183, 28, 28, 0.4) !important;
|
1230 |
+
}
|
1231 |
+
|
1232 |
+
/* Also ensure accordion content has proper dark mode styling */
|
1233 |
+
.gr-accordion .gr-panel:has([data-testid="HTML"]) {
|
1234 |
+
background: var(--background-fill-primary) !important;
|
1235 |
+
color: var(--body-text-color) !important;
|
1236 |
+
}
|
1237 |
+
}
|
1238 |
+
|
1239 |
/* Error Modal Popup Styling */
|
1240 |
.error-modal-overlay {
|
1241 |
position: fixed !important;
|
|
|
1573 |
)
|
1574 |
# Attack results summary (pretty list)
|
1575 |
results_display = gr.HTML("", elem_id="attack-results")
|
1576 |
+
|
1577 |
+
# Flagged content display (only shown when defense enabled and content found)
|
1578 |
+
with gr.Accordion("Show What was Flagged", open=False, visible=False) as flagged_accordion:
|
1579 |
+
flagged_content_display = gr.HTML(
|
1580 |
+
"",
|
1581 |
+
show_label=False
|
1582 |
+
)
|
1583 |
+
|
1584 |
with gr.Accordion("Show Execution Trace", open=False):
|
1585 |
trace_display = gr.Textbox(
|
1586 |
lines=14,
|
|
|
1774 |
gr.update(), # email1_display - no change
|
1775 |
gr.update(), # email2_display - no change
|
1776 |
gr.update(), # email3_display - no change
|
1777 |
+
gr.update(value=modal_html, visible=True), # error_modal_html
|
1778 |
+
gr.update(), # flagged_accordion - no change
|
1779 |
+
gr.update() # flagged_content_display - no change
|
1780 |
)
|
1781 |
|
1782 |
print("✅ ALL VALIDATION PASSED - proceeding with attack submission")
|
|
|
1788 |
}
|
1789 |
|
1790 |
try:
|
1791 |
+
exec_log, final_out, flagged_content = submit_attack(from_addr.strip(), subject, body, model, defense_enabled, user_info.strip(), confidence_scores)
|
1792 |
except Exception as e:
|
1793 |
# Handle any setup or execution errors with detailed messages
|
1794 |
error_str = str(e).lower()
|
|
|
1844 |
gr.update(), # email1_display - no change
|
1845 |
gr.update(), # email2_display - no change
|
1846 |
gr.update(), # email3_display - no change
|
1847 |
+
gr.update(value=modal_html, visible=True), # error_modal_html
|
1848 |
+
gr.update(), # flagged_accordion - no change
|
1849 |
+
gr.update() # flagged_content_display - no change
|
1850 |
)
|
1851 |
|
1852 |
# Build a formatted results summary extracted from exec_log
|
|
|
1891 |
for i, email in enumerate(emails_to_display):
|
1892 |
updated_emails.append(format_single_email(email, i + 1))
|
1893 |
|
1894 |
+
# Process flagged content for display
|
1895 |
+
flagged_display_html = ""
|
1896 |
+
flagged_accordion_visible = False
|
1897 |
+
flagged_accordion_open = False
|
1898 |
+
|
1899 |
+
if defense_enabled and flagged_content and flagged_content.strip():
|
1900 |
+
# Convert newlines to HTML line breaks for proper rendering
|
1901 |
+
flagged_content_html = flagged_content.replace('\n', '<br>')
|
1902 |
+
# Simple HTML structure without extra containers
|
1903 |
+
flagged_display_html = flagged_content_html
|
1904 |
+
flagged_accordion_visible = True
|
1905 |
+
flagged_accordion_open = True # Open after submit when there's content
|
1906 |
+
print(f"🏷️ Flagged content prepared for UI: {len(flagged_content)} characters")
|
1907 |
+
else:
|
1908 |
+
print("🏷️ No flagged content to display")
|
1909 |
+
|
1910 |
# Return results with hidden error modal (validation passed)
|
1911 |
success_timestamp = int(time.time() * 1000)
|
1912 |
print(f"✅ Validation successful at {success_timestamp} - hiding error modal")
|
1913 |
return (final_out, results_html, exec_log, updated_emails[0], updated_emails[1], updated_emails[2],
|
1914 |
+
gr.update(value="", visible=False), # Hide error modal
|
1915 |
+
gr.update(visible=flagged_accordion_visible, open=flagged_accordion_open), # Update flagged accordion
|
1916 |
+
gr.update(value=flagged_display_html)) # Update flagged content
|
1917 |
|
1918 |
submit_btn.click(
|
1919 |
fn=submit_and_update,
|
1920 |
inputs=[attack_from, attack_subject, attack_body, model_selector, defense_toggle, user_info],
|
1921 |
+
outputs=[final_output_display, results_display, trace_display, email1_display, email2_display, email3_display, error_modal_html, flagged_accordion, flagged_content_display]
|
1922 |
)
|
1923 |
|
1924 |
# Connect dismiss trigger to properly hide the modal
|
instruction_classifier.py
CHANGED
@@ -186,7 +186,7 @@ class InstructionClassifierSanitizer:
|
|
186 |
|
187 |
|
188 |
@spaces.GPU
|
189 |
-
def sanitize_with_annotations(self, tool_output: str) -> Tuple[str, List[Dict[str, any]]]:
|
190 |
"""
|
191 |
Sanitization function that also returns annotation data for flagged content.
|
192 |
|
@@ -194,11 +194,13 @@ class InstructionClassifierSanitizer:
|
|
194 |
tool_output: The raw tool output string
|
195 |
|
196 |
Returns:
|
197 |
-
Tuple of (sanitized_output, annotations) where
|
198 |
-
|
|
|
|
|
199 |
"""
|
200 |
if not tool_output or not tool_output.strip():
|
201 |
-
return tool_output, []
|
202 |
|
203 |
# Move model to target device (GPU) within @spaces.GPU decorated method
|
204 |
if self.device != self.target_device:
|
@@ -214,7 +216,7 @@ class InstructionClassifierSanitizer:
|
|
214 |
|
215 |
if not is_injection:
|
216 |
print("✅ No injection detected - returning original output")
|
217 |
-
return tool_output, []
|
218 |
|
219 |
print(f"🚨 Injection detected! Processing with extensions and annotations...")
|
220 |
|
@@ -233,12 +235,13 @@ class InstructionClassifierSanitizer:
|
|
233 |
# Step 5: Remove instruction tags and their content
|
234 |
sanitized_output = self._remove_instruction_tags(merged_tagged_text)
|
235 |
print(f"🔒 Sanitized output: {sanitized_output}")
|
236 |
-
|
|
|
237 |
|
238 |
except Exception as e:
|
239 |
print(f"❌ Error in instruction classifier sanitization: {e}")
|
240 |
# Return original output if sanitization fails
|
241 |
-
return tool_output, []
|
242 |
|
243 |
def _extract_annotations_from_tagged_text(self, tagged_text: str, original_text: str) -> List[Dict[str, any]]:
|
244 |
"""
|
@@ -700,23 +703,25 @@ def sanitize_tool_output_with_annotations(tool_output, defense_enabled=True):
|
|
700 |
defense_enabled: Whether defense is enabled (passed from agent)
|
701 |
|
702 |
Returns:
|
703 |
-
Tuple of (sanitized_output, annotations) where
|
704 |
-
|
|
|
|
|
705 |
"""
|
706 |
print(f"🔍 sanitize_tool_output_with_annotations called with: {tool_output[:100]}...")
|
707 |
|
708 |
# If defense is disabled globally, return original output with no annotations
|
709 |
if not defense_enabled:
|
710 |
print("⚠️ Defense disabled - returning original output without processing")
|
711 |
-
return tool_output, []
|
712 |
|
713 |
sanitizer = get_sanitizer()
|
714 |
if sanitizer is None:
|
715 |
print("⚠️ Instruction classifier not available, returning original output")
|
716 |
-
return tool_output, []
|
717 |
|
718 |
print("✅ Sanitizer found, processing with annotations...")
|
719 |
-
sanitized_output, annotations = sanitizer.sanitize_with_annotations(tool_output)
|
720 |
print(f"🔒 Sanitization complete, result: {sanitized_output[:100]}...")
|
721 |
print(f"📝 Found {len(annotations)} annotations")
|
722 |
-
return sanitized_output, annotations
|
|
|
186 |
|
187 |
|
188 |
@spaces.GPU
|
189 |
+
def sanitize_with_annotations(self, tool_output: str) -> Tuple[str, List[Dict[str, any]], str]:
|
190 |
"""
|
191 |
Sanitization function that also returns annotation data for flagged content.
|
192 |
|
|
|
194 |
tool_output: The raw tool output string
|
195 |
|
196 |
Returns:
|
197 |
+
Tuple of (sanitized_output, annotations, merged_tagged_text) where:
|
198 |
+
- sanitized_output: cleaned text with instruction content removed
|
199 |
+
- annotations: position information for content flagged by classifier
|
200 |
+
- merged_tagged_text: text with <instruction> tags showing detected content
|
201 |
"""
|
202 |
if not tool_output or not tool_output.strip():
|
203 |
+
return tool_output, [], tool_output
|
204 |
|
205 |
# Move model to target device (GPU) within @spaces.GPU decorated method
|
206 |
if self.device != self.target_device:
|
|
|
216 |
|
217 |
if not is_injection:
|
218 |
print("✅ No injection detected - returning original output")
|
219 |
+
return tool_output, [], tool_output
|
220 |
|
221 |
print(f"🚨 Injection detected! Processing with extensions and annotations...")
|
222 |
|
|
|
235 |
# Step 5: Remove instruction tags and their content
|
236 |
sanitized_output = self._remove_instruction_tags(merged_tagged_text)
|
237 |
print(f"🔒 Sanitized output: {sanitized_output}")
|
238 |
+
print(f"🔍 DEBUG SANITIZER: Returning merged_tagged_text: '{merged_tagged_text}'")
|
239 |
+
return sanitized_output, annotations, merged_tagged_text
|
240 |
|
241 |
except Exception as e:
|
242 |
print(f"❌ Error in instruction classifier sanitization: {e}")
|
243 |
# Return original output if sanitization fails
|
244 |
+
return tool_output, [], tool_output
|
245 |
|
246 |
def _extract_annotations_from_tagged_text(self, tagged_text: str, original_text: str) -> List[Dict[str, any]]:
|
247 |
"""
|
|
|
703 |
defense_enabled: Whether defense is enabled (passed from agent)
|
704 |
|
705 |
Returns:
|
706 |
+
Tuple of (sanitized_output, annotations, merged_tagged_text) where:
|
707 |
+
- sanitized_output: cleaned text with instruction content removed
|
708 |
+
- annotations: position information for content flagged by classifier
|
709 |
+
- merged_tagged_text: text with <instruction> tags showing detected content
|
710 |
"""
|
711 |
print(f"🔍 sanitize_tool_output_with_annotations called with: {tool_output[:100]}...")
|
712 |
|
713 |
# If defense is disabled globally, return original output with no annotations
|
714 |
if not defense_enabled:
|
715 |
print("⚠️ Defense disabled - returning original output without processing")
|
716 |
+
return tool_output, [], tool_output
|
717 |
|
718 |
sanitizer = get_sanitizer()
|
719 |
if sanitizer is None:
|
720 |
print("⚠️ Instruction classifier not available, returning original output")
|
721 |
+
return tool_output, [], tool_output
|
722 |
|
723 |
print("✅ Sanitizer found, processing with annotations...")
|
724 |
+
sanitized_output, annotations, merged_tagged_text = sanitizer.sanitize_with_annotations(tool_output)
|
725 |
print(f"🔒 Sanitization complete, result: {sanitized_output[:100]}...")
|
726 |
print(f"📝 Found {len(annotations)} annotations")
|
727 |
+
return sanitized_output, annotations, merged_tagged_text
|
utils.py
CHANGED
@@ -1,147 +1,19 @@
|
|
1 |
import json
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
-
from torch.utils.data import Dataset
|
5 |
-
from transformers import
|
6 |
-
import numpy as np
|
7 |
-
from tqdm import tqdm
|
8 |
import re
|
9 |
-
from typing import List,
|
10 |
import warnings
|
11 |
import logging
|
12 |
import os
|
13 |
-
from datetime import datetime
|
14 |
-
from sklearn.utils.class_weight import compute_class_weight
|
15 |
-
import torch.nn.functional as F
|
16 |
|
17 |
# Disable tokenizer parallelism to avoid forking warnings
|
18 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
19 |
|
20 |
warnings.filterwarnings('ignore')
|
21 |
|
22 |
-
def set_random_seeds(seed=42):
    """Seed Python, NumPy and PyTorch RNGs with ``seed`` for reproducibility.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    mode (slower but reproducible).
    """
    import random
    import numpy as np
    import torch

    # Same order as before: stdlib, NumPy, torch CPU, then all CUDA devices
    # (manual_seed_all covers the multi-GPU case).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
36 |
-
|
37 |
-
def setup_logging(log_dir='data/logs'):
    """Setup logging configuration.

    Creates ``log_dir`` if needed, builds a timestamped log file path inside
    it, and asks ``logging.basicConfig`` to log at INFO level to both that file
    and the console.

    Args:
        log_dir (str): Directory in which the log file is created.

    Returns:
        tuple: ``(logger, log_file)`` - the module logger and the path of the
        log file.
    """
    # Create logs directory if it doesn't exist
    os.makedirs(log_dir, exist_ok=True)

    # Create timestamp for log file
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f'training_log_{timestamp}.log')

    # Configure logging.
    # NOTE: basicConfig leaves an already-configured root logger untouched, so
    # a second call in the same process does not redirect output to the new
    # file.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # Explicit UTF-8: messages logged by this module contain non-ASCII
            # glyphs, which a platform-default encoding may not be able to
            # write.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),  # Also print to console
        ],
    )

    logger = logging.getLogger(__name__)
    logger.info(f"Logging initialized. Log file: {log_file}")
    return logger, log_file
|
59 |
-
|
60 |
-
def check_gpu_availability():
    """Log which GPU backends (MPS, CUDA) PyTorch reports, plus its version."""
    log = logging.getLogger(__name__)
    log.info("=== GPU Availability Check ===")

    # Apple Silicon backend: the "built" status is only reported when MPS is available.
    mps = torch.backends.mps
    if not mps.is_available():
        log.info("✗ MPS (Apple Silicon GPU) is not available")
    else:
        log.info("✓ MPS (Apple Silicon GPU) is available")
        if mps.is_built():
            built_msg = "✓ MPS is built into PyTorch"
        else:
            built_msg = "✗ MPS is not built into PyTorch"
        log.info(built_msg)

    # CUDA backend, including the device count when present.
    if torch.cuda.is_available():
        cuda_msg = f"✓ CUDA is available (GPU count: {torch.cuda.device_count()})"
    else:
        cuda_msg = "✗ CUDA is not available"
    log.info(cuda_msg)

    log.info(f"PyTorch version: {torch.__version__}")
    log.info("=" * 50)
|
81 |
-
|
82 |
-
def calculate_class_weights(dataset):
    """Calculate class weights for imbalanced dataset using BERT paper approach.

    Collects every label other than -100 from ``window_data['subword_labels']``
    across ``dataset.processed_data``, computes 'balanced' class weights with
    sklearn's ``compute_class_weight``, logs the distribution, and returns the
    weights as a float tensor.

    Args:
        dataset: Object exposing ``processed_data``, an iterable of mappings
            that each carry a ``'subword_labels'`` sequence.

    Returns:
        torch.FloatTensor: One weight per class actually present in the data,
        in ascending class order.

    Raises:
        ValueError: If the dataset contains no label other than -100.
    """
    logger = logging.getLogger(__name__)

    # Collect all labels from the dataset (BERT approach: only first subtokens have real labels)
    all_labels = []
    for window_data in dataset.processed_data:
        # Filter out -100 labels (special tokens + subsequent subtokens of same word)
        # This gives us true word-level class distribution
        all_labels.extend(label for label in window_data['subword_labels'] if label != -100)

    # Fail early with a clear message instead of an opaque error further down.
    if not all_labels:
        raise ValueError("Cannot compute class weights: dataset contains no labels other than -100")

    # Convert to numpy array
    y = np.array(all_labels)

    # Calculate class weights using sklearn
    classes = np.unique(y)
    class_weights = compute_class_weight('balanced', classes=classes, y=y)

    # Create weight tensor
    weight_tensor = torch.FloatTensor(class_weights)

    logger.info(f"Word-level class distribution: {np.bincount(y)}")
    logger.info(f"Class 0 (Non-instruction words): {np.sum(y == 0)} words ({np.sum(y == 0)/len(y)*100:.1f}%)")
    logger.info(f"Class 1 (Instruction words): {np.sum(y == 1)} words ({np.sum(y == 1)/len(y)*100:.1f}%)")
    logger.info(f"Calculated class weights (word-level): {class_weights}")

    # Log one line per class that is actually present; indexing [0] and [1]
    # unconditionally raises IndexError when the data holds a single class.
    class_names = {0: "Non-instruction", 1: "Instruction"}
    for cls, weight in zip(classes, class_weights):
        logger.info(f" Weight for class {cls} ({class_names.get(int(cls), 'unknown')}): {weight:.4f}")

    return weight_tensor
|
112 |
-
|
113 |
-
class FocalLoss(nn.Module):
    """Focal Loss for addressing class imbalance.

    The per-element loss is ``alpha * (1 - pt) ** gamma * ce``, where ``ce`` is
    the unreduced cross entropy and ``pt = exp(-ce)``. Targets equal to
    ``ignore_index`` are excluded, and the mean over the remaining elements is
    returned.

    Args:
        alpha: Scalar multiplier applied to every element's loss.
        gamma: Exponent of the ``(1 - pt)`` modulating factor.
        ignore_index: Target value whose positions are skipped.
    """

    def __init__(self, alpha=1, gamma=2, ignore_index=-100):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        """Return the mean focal loss of ``inputs`` (logits) against ``targets``.

        ``inputs`` is flattened to ``(-1, inputs.size(-1))`` and ``targets`` to
        one dimension before masking. When every target is ignored the result
        is a zero that is still connected to ``inputs``.
        """
        # Flatten inputs and targets (reshape also accepts non-contiguous tensors)
        inputs = inputs.reshape(-1, inputs.size(-1))
        targets = targets.reshape(-1)

        # Create mask for non-ignored indices
        mask = targets != self.ignore_index
        targets = targets[mask]
        inputs = inputs[mask]

        if targets.numel() == 0:
            # Keep the zero attached to the autograd graph so backward()
            # produces zero gradients for the inputs rather than none at all.
            return inputs.sum() * 0.0

        # Calculate cross entropy
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')

        # Calculate pt
        pt = torch.exp(-ce_loss)

        # Calculate focal loss
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        return focal_loss.mean()
|
144 |
-
|
145 |
class InstructionDataset(Dataset):
|
146 |
def __init__(self, data_path: str, tokenizer, max_length: int = 512, is_training: bool = True,
|
147 |
window_size: int = 512, overlap: int = 100):
|
@@ -517,8 +389,6 @@ class TransformerInstructionClassifier(nn.Module):
|
|
517 |
# Setup loss function based on type
|
518 |
if loss_type == 'weighted_ce':
|
519 |
self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, weight=class_weights)
|
520 |
-
elif loss_type == 'focal':
|
521 |
-
self.loss_fct = FocalLoss(alpha=1, gamma=2, ignore_index=-100)
|
522 |
else:
|
523 |
self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
524 |
|
|
|
1 |
import json
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
from transformers import AutoModel
|
|
|
|
|
6 |
import re
|
7 |
+
from typing import List, Dict, Any
|
8 |
import warnings
|
9 |
import logging
|
10 |
import os
|
|
|
|
|
|
|
11 |
|
12 |
# Disable tokenizer parallelism to avoid forking warnings
|
13 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
14 |
|
15 |
warnings.filterwarnings('ignore')
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
class InstructionDataset(Dataset):
|
18 |
def __init__(self, data_path: str, tokenizer, max_length: int = 512, is_training: bool = True,
|
19 |
window_size: int = 512, overlap: int = 100):
|
|
|
389 |
# Setup loss function based on type
|
390 |
if loss_type == 'weighted_ce':
|
391 |
self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, weight=class_weights)
|
|
|
|
|
392 |
else:
|
393 |
self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
394 |
|