import os
import random
import zipfile
from difflib import Differ

import gradio as gr
import nltk
import pandas as pd
from findfile import find_files

from anonymous_demo import TADCheckpointManager
from textattack import Attacker
from textattack.attack_recipes import (
    BAEGarg2019,
    PWWSRen2019,
    TextFoolerJin2019,
    PSOZang2020,
    IGAWang2019,
    GeneticAlgorithmAlzantot2018,
    DeepWordBugGao2018,
    CLARE2020,
)
from textattack.attack_results import SuccessfulAttackResult
from textattack.datasets import Dataset
from textattack.models.wrappers import HuggingFaceModelWrapper
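# Unpack the bundled TAD checkpoints into the working directory so that
# TADCheckpointManager can locate them below.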
with zipfile.ZipFile("checkpoints.zip", "r") as z:
    z.extractall(os.getcwd())
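# Thin TextAttack wrapper around a TAD classifier: maps a batch of raw texts
# to the class-probability vectors returned by model.infer().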
class ModelWrapper(HuggingFaceModelWrapper):
    def __init__(self, model):
        self.model = model

    def __call__(self, text_inputs, **kwargs):
        outputs = []
        for text_input in text_inputs:
            raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
            outputs.append(raw_outputs["probs"])
        return outputs
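# Bundles an attack recipe (e.g. BAE, PWWS, TextFooler) with a wrapped TAD
# classifier; the Attacker is driven one example at a time via simple_attack()
# in generate_adversarial_example() below.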
class SentAttacker:
    def __init__(self, model, recipe_class=BAEGarg2019):
        model_wrapper = ModelWrapper(model)

        recipe = recipe_class.build(model_wrapper)

        # TextAttack's Attacker requires a dataset at construction time; a
        # single placeholder example is enough because attacks are issued per
        # example via simple_attack().
        _dataset = [("", 0)]
        _dataset = Dataset(_dataset)

        self.attacker = Attacker(recipe, _dataset)
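# Character-level diff used by the highlighted text views: difflib compares
# the two strings character by character, tagging insertions (+) and
# deletions (-).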
def diff_texts(text1, text2):
    d = Differ()
    return [
        (token[2:], token[0] if token[0] != " " else None)
        for token in d.compare(text1, text2)
    ]
def get_ensembled_tad_results(results):
    # Majority vote over the predicted labels of the individual results.
    target_dict = {}
    for r in results:
        target_dict[r["label"]] = target_dict.get(r["label"], 0) + 1

    return max(target_dict, key=target_dict.get)
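# Download the WordNet data used by the synonym-based attackers, then load one
# TAD classifier per dataset and build one SentAttacker per
# (dataset, attack recipe) pair. Each classifier is also handed the PWWS-based
# attacker via its sent_attacker attribute, matching the defense="pwws"
# setting passed to infer() below.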
nltk.download("omw-1.4")

sent_attackers = {}
tad_classifiers = {}

attack_recipes = {
    "bae": BAEGarg2019,
    "pwws": PWWSRen2019,
    "textfooler": TextFoolerJin2019,
    "pso": PSOZang2020,
    "iga": IGAWang2019,
    "ga": GeneticAlgorithmAlzantot2018,
    "deepwordbug": DeepWordBugGao2018,
    "clare": CLARE2020,
}

for attacker in ["pwws", "bae", "textfooler", "deepwordbug"]:
    for dataset in [
        "agnews10k",
        "amazon",
        "sst2",
    ]:
        if "tad-{}".format(dataset) not in tad_classifiers:
            tad_classifiers[
                "tad-{}".format(dataset)
            ] = TADCheckpointManager.get_tad_text_classifier(
                "tad-{}".format(dataset).upper()
            )

        sent_attackers["tad-{}{}".format(dataset, attacker)] = SentAttacker(
            tad_classifiers["tad-{}".format(dataset)], attack_recipes[attacker]
        )
        tad_classifiers["tad-{}".format(dataset)].sent_attacker = sent_attackers[
            "tad-{}pwws".format(dataset)
        ]
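# The four helpers below locate the local text_defense test split of the
# corresponding dataset (lines formatted as "<text>$LABEL$<label>") and return
# a random (text, label) pair to seed the attack.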
def get_sst2_example():
    filter_key_words = [
        ".py",
        ".md",
        "readme",
        "log",
        "result",
        "zip",
        ".state_dict",
        ".model",
        ".png",
        "acc_",
        "f1_",
        ".origin",
        ".adv",
        ".csv",
    ]

    dataset_file = {"train": [], "test": [], "valid": []}
    dataset = "sst2"
    search_path = "./"
    task = "text_defense"
    dataset_file["test"] += find_files(
        search_path,
        [dataset, "test", task],
        exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
        + filter_key_words,
    )

    for dat_type in ["test"]:
        data = []
        label_set = set()
        for data_file in dataset_file[dat_type]:
            with open(data_file, mode="r", encoding="utf8") as fin:
                lines = fin.readlines()
                for line in lines:
                    text, label = line.split("$LABEL$")
                    text = text.strip()
                    label = int(label.strip())
                    data.append((text, label))
                    label_set.add(label)
        return random.choice(data)
def get_agnews_example():
    filter_key_words = [
        ".py",
        ".md",
        "readme",
        "log",
        "result",
        "zip",
        ".state_dict",
        ".model",
        ".png",
        "acc_",
        "f1_",
        ".origin",
        ".adv",
        ".csv",
    ]

    dataset_file = {"train": [], "test": [], "valid": []}
    dataset = "agnews"
    search_path = "./"
    task = "text_defense"
    dataset_file["test"] += find_files(
        search_path,
        [dataset, "test", task],
        exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
        + filter_key_words,
    )

    for dat_type in ["test"]:
        data = []
        label_set = set()
        for data_file in dataset_file[dat_type]:
            with open(data_file, mode="r", encoding="utf8") as fin:
                lines = fin.readlines()
                for line in lines:
                    text, label = line.split("$LABEL$")
                    text = text.strip()
                    label = int(label.strip())
                    data.append((text, label))
                    label_set.add(label)
        return random.choice(data)
def get_amazon_example():
    filter_key_words = [
        ".py",
        ".md",
        "readme",
        "log",
        "result",
        "zip",
        ".state_dict",
        ".model",
        ".png",
        "acc_",
        "f1_",
        ".origin",
        ".adv",
        ".csv",
    ]

    dataset_file = {"train": [], "test": [], "valid": []}
    dataset = "amazon"
    search_path = "./"
    task = "text_defense"
    dataset_file["test"] += find_files(
        search_path,
        [dataset, "test", task],
        exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
        + filter_key_words,
    )

    for dat_type in ["test"]:
        data = []
        label_set = set()
        for data_file in dataset_file[dat_type]:
            with open(data_file, mode="r", encoding="utf8") as fin:
                lines = fin.readlines()
                for line in lines:
                    text, label = line.split("$LABEL$")
                    text = text.strip()
                    label = int(label.strip())
                    data.append((text, label))
                    label_set.add(label)
        return random.choice(data)
def get_imdb_example():
    filter_key_words = [
        ".py",
        ".md",
        "readme",
        "log",
        "result",
        "zip",
        ".state_dict",
        ".model",
        ".png",
        "acc_",
        "f1_",
        ".origin",
        ".adv",
        ".csv",
    ]

    dataset_file = {"train": [], "test": [], "valid": []}
    dataset = "imdb"
    search_path = "./"
    task = "text_defense"
    dataset_file["test"] += find_files(
        search_path,
        [dataset, "test", task],
        exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
        + filter_key_words,
    )

    for dat_type in ["test"]:
        data = []
        label_set = set()
        for data_file in dataset_file[dat_type]:
            with open(data_file, mode="r", encoding="utf8") as fin:
                lines = fin.readlines()
                for line in lines:
                    text, label = line.split("$LABEL$")
                    text = text.strip()
                    label = int(label.strip())
                    data.append((text, label))
                    label_set.add(label)
        return random.choice(data)
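# End-to-end demo pipeline: pick (or accept) a natural example, attack it with
# the selected recipe, and, if the attack flips a previously correct
# prediction, run the TAD classifier with defense enabled to detect and repair
# the adversarial example. Already-used texts are cached so repeated clicks
# draw fresh examples.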
cache = set()


def generate_adversarial_example(dataset, attacker, text=None, label=None):
    # Draw a fresh example from the selected dataset if no text was given or
    # the given text has already been used.
    if not text or text in cache:
        if "agnews" in dataset.lower():
            text, label = get_agnews_example()
        elif "sst2" in dataset.lower():
            text, label = get_sst2_example()
        elif "amazon" in dataset.lower():
            text, label = get_amazon_example()
        elif "imdb" in dataset.lower():
            text, label = get_imdb_example()

    cache.add(text)

    result = None
    attack_result = sent_attackers[
        "tad-{}{}".format(dataset.lower(), attacker.lower())
    ].attacker.simple_attack(text, int(label))
    if isinstance(attack_result, SuccessfulAttackResult):
        if (
            attack_result.perturbed_result.output
            != attack_result.original_result.ground_truth_output
        ) and (
            attack_result.original_result.output
            == attack_result.original_result.ground_truth_output
        ):
            # Only keep attacks that flip a prediction that was correct on the
            # original text, then let the TAD classifier detect and repair the
            # perturbed text. The "!ref!" suffix carries the reference labels
            # that the classifier reports back (e.g. ref_label_check).
            result = tad_classifiers["tad-{}".format(dataset.lower())].infer(
                attack_result.perturbed_result.attacked_text.text
                + "!ref!{},{},{}".format(
                    attack_result.original_result.ground_truth_output,
                    1,
                    attack_result.perturbed_result.output,
                ),
                print_result=True,
                defense="pwws",
            )

    if result:
        classification_df = {}
        classification_df["is_repaired"] = result["is_fixed"]
        classification_df["pred_label"] = result["label"]
        classification_df["confidence"] = round(result["confidence"], 3)
        classification_df["is_correct"] = result["ref_label_check"]

        advdetection_df = {}
        if result["is_adv_label"] != "0":
            advdetection_df["is_adversarial"] = {
                "0": False,
                "1": True,
                0: False,
                1: True,
            }[result["is_adv_label"]]
            advdetection_df["perturbed_label"] = result["perturbed_label"]
            advdetection_df["confidence"] = round(result["is_adv_confidence"], 3)

    else:
        # The attack failed or did not yield a usable adversarial example;
        # retry with a newly sampled natural example.
        return generate_adversarial_example(dataset, attacker)

    return (
        text,
        label,
        result["restored_text"],
        result["label"],
        attack_result.perturbed_result.attacked_text.text,
        diff_texts(text, text),
        diff_texts(text, attack_result.perturbed_result.attacked_text.text),
        diff_texts(text, result["restored_text"]),
        attack_result.perturbed_result.output,
        pd.DataFrame(classification_df, index=[0]),
        pd.DataFrame(advdetection_df, index=[0]),
    )
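# Gradio front end: dataset/attacker selection, an optional custom input, and
# panels showing the original, adversarial, and repaired texts together with
# the detection and classification results.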
demo = gr.Blocks()
with demo:
    gr.Markdown(
        "# <p align='center'> Reactive Perturbation Defocusing for Textual Adversarial Defense </p> "
    )

    gr.Markdown("## <p align='center'>Clarifications</p>")
    gr.Markdown(
        "- This demo has no mechanism to guarantee that an adversarial example will be correctly repaired by RPD."
        " The repair success rate corresponds to the performance reported in the paper (up to approximately 97%)."
    )
    gr.Markdown(
        "- The red (+) and green (-) colors in the character edits indicate characters that were added "
        "or deleted in the adversarial example compared to the original natural example."
    )
    gr.Markdown(
        "- The adversarial example and the repaired adversarial example may read unnaturally; "
        "this is because attackers usually generate unnatural perturbations. "
        "RPD does not introduce additional unnatural perturbations."
    )
    gr.Markdown(
        "- To the best of our knowledge, Reactive Perturbation Defocusing is a novel approach to adversarial "
        "defense. RPD significantly outperforms the state-of-the-art methods (>10% improvement in defense accuracy)."
    )
    gr.Markdown(
        "- DeepWordBug is an attacker unseen by RPD's adversarial detector, which demonstrates the robustness of RPD."
    )
    gr.Markdown("## <p align='center'>Natural Example Input</p>")
    with gr.Group():
        with gr.Row():
            input_dataset = gr.Radio(
                choices=["SST2", "AGNews10K", "Amazon"],
                value="SST2",
                label="Select a testing dataset; a random example from it is used to generate the adversarial example.",
            )
            input_attacker = gr.Radio(
                choices=[
                    "BAE",
                    "PWWS",
                    "TextFooler",
                    "DeepWordBug",
                ],
                value="TextFooler",
                label="Choose an adversarial attacker used to generate an adversarial example against the model.",
            )
    with gr.Group():
        with gr.Row():
            input_sentence = gr.Textbox(
                placeholder="Input a natural example...",
                label="Alternatively, input a natural example and its original label to generate an adversarial example.",
            )
            input_label = gr.Textbox(
                placeholder="Original label...", label="Original Label"
            )

    button_gen = gr.Button(
        "Generate an adversarial example and repair it using RPD (no GPU; takes about 3-10 minutes)",
        variant="primary",
    )
    gr.Markdown(
        "## <p align='center'>Generated Adversarial Example and Repaired Adversarial Example</p>"
    )
    with gr.Group():
        with gr.Column():
            with gr.Row():
                output_original_example = gr.Textbox(label="Original Example")
                output_original_label = gr.Textbox(label="Original Label")
            with gr.Row():
                output_adv_example = gr.Textbox(label="Adversarial Example")
                output_adv_label = gr.Textbox(label="Perturbed Label")
            with gr.Row():
                output_repaired_example = gr.Textbox(
                    label="Repaired Adversarial Example by RPD"
                )
                output_repaired_label = gr.Textbox(label="Repaired Label")

    gr.Markdown(
        "## <p align='center'>The Output of Reactive Perturbation Defocusing</p>"
    )
    with gr.Group():
        output_is_adv_df = gr.DataFrame(label="Adversarial Example Detection Result")
        gr.Markdown(
            "The is_adversarial field indicates whether an adversarial example was detected. "
            "The perturbed_label field is the predicted label of the adversarial example. "
            "The confidence field is the confidence of the adversarial example detection."
        )
        output_df = gr.DataFrame(label="Repaired Standard Classification Result")
        gr.Markdown(
            "If is_repaired is true, the example has been repaired by RPD. "
            "The pred_label field is the standard classification result. "
            "The confidence field is the confidence of the predicted label. "
            "The is_correct field indicates whether the predicted label is correct."
        )

    gr.Markdown("## <p align='center'>Example Comparisons</p>")
    ori_text_diff = gr.HighlightedText(
        label="The Original Natural Example",
        combine_adjacent=True,
    )
    adv_text_diff = gr.HighlightedText(
        label="Character Edits in the Adversarial Example Compared to the Natural Example",
        combine_adjacent=True,
    )
    restored_text_diff = gr.HighlightedText(
        label="Character Edits in the Repaired Adversarial Example Compared to the Natural Example",
        combine_adjacent=True,
    )
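    # Wire the button to the attack-and-repair pipeline; the output order must
    # match the tuple returned by generate_adversarial_example().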
    button_gen.click(
        fn=generate_adversarial_example,
        inputs=[input_dataset, input_attacker, input_sentence, input_label],
        outputs=[
            output_original_example,
            output_original_label,
            output_repaired_example,
            output_repaired_label,
            output_adv_example,
            ori_text_diff,
            adv_text_diff,
            restored_text_diff,
            output_adv_label,
            output_df,
            output_is_adv_df,
        ],
    )

demo.launch()