anonymous8
		
	commited on
		
		
					Commit 
							
							·
						
						4f6b345
	
1
								Parent(s):
							
							04b0636
								
update
Browse files- .gitignore +1 -1
- README.md +1 -1
- app.py +174 -374
- requirements.txt +2 -1
- utils.py +234 -0
    	
        .gitignore
    CHANGED
    
    | @@ -2,7 +2,7 @@ | |
| 2 | 
             
            *.cache
         | 
| 3 | 
             
            *.dev.py
         | 
| 4 | 
             
            state_dict/
         | 
| 5 | 
            -
             | 
| 6 | 
             
            # Byte-compiled / optimized / DLL files
         | 
| 7 | 
             
            __pycache__/
         | 
| 8 | 
             
            *.py[cod]
         | 
|  | |
| 2 | 
             
            *.cache
         | 
| 3 | 
             
            *.dev.py
         | 
| 4 | 
             
            state_dict/
         | 
| 5 | 
            +
            TAD*/
         | 
| 6 | 
             
            # Byte-compiled / optimized / DLL files
         | 
| 7 | 
             
            __pycache__/
         | 
| 8 | 
             
            *.py[cod]
         | 
    	
        README.md
    CHANGED
    
    | @@ -4,7 +4,7 @@ emoji: 🛡️ | |
| 4 | 
             
            colorFrom: gray
         | 
| 5 | 
             
            colorTo: green
         | 
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
            -
            sdk_version: | 
| 8 | 
             
            app_file: app.py
         | 
| 9 | 
             
            pinned: false
         | 
| 10 | 
             
            license: mit
         | 
|  | |
| 4 | 
             
            colorFrom: gray
         | 
| 5 | 
             
            colorTo: green
         | 
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
            +
            sdk_version:  3.20.1
         | 
| 8 | 
             
            app_file: app.py
         | 
| 9 | 
             
            pinned: false
         | 
| 10 | 
             
            license: mit
         | 
    	
        app.py
    CHANGED
    
    | @@ -1,15 +1,13 @@ | |
| 1 | 
             
            import os
         | 
| 2 | 
            -
            import random
         | 
| 3 | 
             
            import zipfile
         | 
| 4 | 
            -
            from difflib import Differ
         | 
| 5 |  | 
| 6 | 
             
            import gradio as gr
         | 
| 7 | 
             
            import nltk
         | 
| 8 | 
             
            import pandas as pd
         | 
| 9 | 
            -
             | 
|  | |
| 10 |  | 
| 11 | 
             
            from anonymous_demo import TADCheckpointManager
         | 
| 12 | 
            -
            from textattack import Attacker
         | 
| 13 | 
             
            from textattack.attack_recipes import (
         | 
| 14 | 
             
                BAEGarg2019,
         | 
| 15 | 
             
                PWWSRen2019,
         | 
| @@ -21,60 +19,7 @@ from textattack.attack_recipes import ( | |
| 21 | 
             
                CLARE2020,
         | 
| 22 | 
             
            )
         | 
| 23 | 
             
            from textattack.attack_results import SuccessfulAttackResult
         | 
| 24 | 
            -
            from  | 
| 25 | 
            -
            from textattack.models.wrappers import HuggingFaceModelWrapper
         | 
| 26 | 
            -
             | 
| 27 | 
            -
            z = zipfile.ZipFile("checkpoints.zip", "r")
         | 
| 28 | 
            -
            z.extractall(os.getcwd())
         | 
| 29 | 
            -
             | 
| 30 | 
            -
             | 
| 31 | 
            -
            class ModelWrapper(HuggingFaceModelWrapper):
         | 
| 32 | 
            -
                def __init__(self, model):
         | 
| 33 | 
            -
                    self.model = model  # pipeline = pipeline
         | 
| 34 | 
            -
             | 
| 35 | 
            -
                def __call__(self, text_inputs, **kwargs):
         | 
| 36 | 
            -
                    outputs = []
         | 
| 37 | 
            -
                    for text_input in text_inputs:
         | 
| 38 | 
            -
                        raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
         | 
| 39 | 
            -
                        outputs.append(raw_outputs["probs"])
         | 
| 40 | 
            -
                    return outputs
         | 
| 41 | 
            -
             | 
| 42 | 
            -
             | 
| 43 | 
            -
            class SentAttacker:
         | 
| 44 | 
            -
                def __init__(self, model, recipe_class=BAEGarg2019):
         | 
| 45 | 
            -
                    model = model
         | 
| 46 | 
            -
                    model_wrapper = ModelWrapper(model)
         | 
| 47 | 
            -
             | 
| 48 | 
            -
                    recipe = recipe_class.build(model_wrapper)
         | 
| 49 | 
            -
                    # WordNet defaults to english. Set the default language to French ('fra')
         | 
| 50 | 
            -
             | 
| 51 | 
            -
                    # recipe.transformation.language = "en"
         | 
| 52 | 
            -
             | 
| 53 | 
            -
                    _dataset = [("", 0)]
         | 
| 54 | 
            -
                    _dataset = Dataset(_dataset)
         | 
| 55 | 
            -
             | 
| 56 | 
            -
                    self.attacker = Attacker(recipe, _dataset)
         | 
| 57 | 
            -
             | 
| 58 | 
            -
             | 
| 59 | 
            -
            def diff_texts(text1, text2):
         | 
| 60 | 
            -
                d = Differ()
         | 
| 61 | 
            -
                return [
         | 
| 62 | 
            -
                    (token[2:], token[0] if token[0] != " " else None)
         | 
| 63 | 
            -
                    for token in d.compare(text1, text2)
         | 
| 64 | 
            -
                ]
         | 
| 65 | 
            -
             | 
| 66 | 
            -
             | 
| 67 | 
            -
            def get_ensembled_tad_results(results):
         | 
| 68 | 
            -
                target_dict = {}
         | 
| 69 | 
            -
                for r in results:
         | 
| 70 | 
            -
                    target_dict[r["label"]] = (
         | 
| 71 | 
            -
                        target_dict.get(r["label"]) + 1 if r["label"] in target_dict else 1
         | 
| 72 | 
            -
                    )
         | 
| 73 | 
            -
             | 
| 74 | 
            -
                return dict(zip(target_dict.values(), target_dict.keys()))[
         | 
| 75 | 
            -
                    max(target_dict.values())
         | 
| 76 | 
            -
                ]
         | 
| 77 | 
            -
             | 
| 78 |  | 
| 79 | 
             
            nltk.download("omw-1.4")
         | 
| 80 |  | 
| @@ -89,204 +34,37 @@ attack_recipes = { | |
| 89 | 
             
                "iga": IGAWang2019,
         | 
| 90 | 
             
                "ga": GeneticAlgorithmAlzantot2018,
         | 
| 91 | 
             
                "deepwordbug": DeepWordBugGao2018,
         | 
| 92 | 
            -
                 | 
| 93 | 
             
            }
         | 
| 94 |  | 
| 95 | 
            -
             | 
| 96 | 
            -
             | 
| 97 | 
            -
             | 
| 98 | 
            -
             | 
| 99 | 
            -
             | 
| 100 | 
            -
                     | 
| 101 | 
            -
             | 
| 102 | 
            -
             | 
| 103 | 
            -
             | 
| 104 | 
            -
             | 
| 105 | 
            -
                         | 
| 106 | 
            -
             | 
| 107 | 
            -
                         | 
| 108 | 
            -
             | 
| 109 | 
            -
                     | 
| 110 | 
            -
                         | 
| 111 | 
            -
             | 
| 112 | 
            -
             | 
| 113 | 
            -
             | 
| 114 | 
            -
             | 
| 115 | 
            -
             | 
| 116 | 
            -
             | 
| 117 | 
            -
            def get_sst2_example():
         | 
| 118 | 
            -
                filter_key_words = [
         | 
| 119 | 
            -
                    ".py",
         | 
| 120 | 
            -
                    ".md",
         | 
| 121 | 
            -
                    "readme",
         | 
| 122 | 
            -
                    "log",
         | 
| 123 | 
            -
                    "result",
         | 
| 124 | 
            -
                    "zip",
         | 
| 125 | 
            -
                    ".state_dict",
         | 
| 126 | 
            -
                    ".model",
         | 
| 127 | 
            -
                    ".png",
         | 
| 128 | 
            -
                    "acc_",
         | 
| 129 | 
            -
                    "f1_",
         | 
| 130 | 
            -
                    ".origin",
         | 
| 131 | 
            -
                    ".adv",
         | 
| 132 | 
            -
                    ".csv",
         | 
| 133 | 
            -
                ]
         | 
| 134 | 
            -
             | 
| 135 | 
            -
                dataset_file = {"train": [], "test": [], "valid": []}
         | 
| 136 | 
            -
                dataset = "sst2"
         | 
| 137 | 
            -
                search_path = "./"
         | 
| 138 | 
            -
                task = "text_defense"
         | 
| 139 | 
            -
                dataset_file["test"] += find_files(
         | 
| 140 | 
            -
                    search_path,
         | 
| 141 | 
            -
                    [dataset, "test", task],
         | 
| 142 | 
            -
                    exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
         | 
| 143 | 
            -
                    + filter_key_words,
         | 
| 144 | 
            -
                )
         | 
| 145 | 
            -
             | 
| 146 | 
            -
                for dat_type in ["test"]:
         | 
| 147 | 
            -
                    data = []
         | 
| 148 | 
            -
                    label_set = set()
         | 
| 149 | 
            -
                    for data_file in dataset_file[dat_type]:
         | 
| 150 | 
            -
                        with open(data_file, mode="r", encoding="utf8") as fin:
         | 
| 151 | 
            -
                            lines = fin.readlines()
         | 
| 152 | 
            -
                            for line in lines:
         | 
| 153 | 
            -
                                text, label = line.split("$LABEL$")
         | 
| 154 | 
            -
                                text = text.strip()
         | 
| 155 | 
            -
                                label = int(label.strip())
         | 
| 156 | 
            -
                                data.append((text, label))
         | 
| 157 | 
            -
                                label_set.add(label)
         | 
| 158 | 
            -
                    return data[random.randint(0, len(data))]
         | 
| 159 | 
            -
             | 
| 160 | 
            -
             | 
| 161 | 
            -
            def get_agnews_example():
         | 
| 162 | 
            -
                filter_key_words = [
         | 
| 163 | 
            -
                    ".py",
         | 
| 164 | 
            -
                    ".md",
         | 
| 165 | 
            -
                    "readme",
         | 
| 166 | 
            -
                    "log",
         | 
| 167 | 
            -
                    "result",
         | 
| 168 | 
            -
                    "zip",
         | 
| 169 | 
            -
                    ".state_dict",
         | 
| 170 | 
            -
                    ".model",
         | 
| 171 | 
            -
                    ".png",
         | 
| 172 | 
            -
                    "acc_",
         | 
| 173 | 
            -
                    "f1_",
         | 
| 174 | 
            -
                    ".origin",
         | 
| 175 | 
            -
                    ".adv",
         | 
| 176 | 
            -
                    ".csv",
         | 
| 177 | 
            -
                ]
         | 
| 178 | 
            -
             | 
| 179 | 
            -
                dataset_file = {"train": [], "test": [], "valid": []}
         | 
| 180 | 
            -
                dataset = "agnews"
         | 
| 181 | 
            -
                search_path = "./"
         | 
| 182 | 
            -
                task = "text_defense"
         | 
| 183 | 
            -
                dataset_file["test"] += find_files(
         | 
| 184 | 
            -
                    search_path,
         | 
| 185 | 
            -
                    [dataset, "test", task],
         | 
| 186 | 
            -
                    exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
         | 
| 187 | 
            -
                    + filter_key_words,
         | 
| 188 | 
            -
                )
         | 
| 189 | 
            -
                for dat_type in ["test"]:
         | 
| 190 | 
            -
                    data = []
         | 
| 191 | 
            -
                    label_set = set()
         | 
| 192 | 
            -
                    for data_file in dataset_file[dat_type]:
         | 
| 193 | 
            -
                        with open(data_file, mode="r", encoding="utf8") as fin:
         | 
| 194 | 
            -
                            lines = fin.readlines()
         | 
| 195 | 
            -
                            for line in lines:
         | 
| 196 | 
            -
                                text, label = line.split("$LABEL$")
         | 
| 197 | 
            -
                                text = text.strip()
         | 
| 198 | 
            -
                                label = int(label.strip())
         | 
| 199 | 
            -
                                data.append((text, label))
         | 
| 200 | 
            -
                                label_set.add(label)
         | 
| 201 | 
            -
                    return data[random.randint(0, len(data))]
         | 
| 202 | 
            -
             | 
| 203 | 
            -
             | 
| 204 | 
            -
            def get_amazon_example():
         | 
| 205 | 
            -
                filter_key_words = [
         | 
| 206 | 
            -
                    ".py",
         | 
| 207 | 
            -
                    ".md",
         | 
| 208 | 
            -
                    "readme",
         | 
| 209 | 
            -
                    "log",
         | 
| 210 | 
            -
                    "result",
         | 
| 211 | 
            -
                    "zip",
         | 
| 212 | 
            -
                    ".state_dict",
         | 
| 213 | 
            -
                    ".model",
         | 
| 214 | 
            -
                    ".png",
         | 
| 215 | 
            -
                    "acc_",
         | 
| 216 | 
            -
                    "f1_",
         | 
| 217 | 
            -
                    ".origin",
         | 
| 218 | 
            -
                    ".adv",
         | 
| 219 | 
            -
                    ".csv",
         | 
| 220 | 
            -
                ]
         | 
| 221 | 
            -
             | 
| 222 | 
            -
                dataset_file = {"train": [], "test": [], "valid": []}
         | 
| 223 | 
            -
                dataset = "amazon"
         | 
| 224 | 
            -
                search_path = "./"
         | 
| 225 | 
            -
                task = "text_defense"
         | 
| 226 | 
            -
                dataset_file["test"] += find_files(
         | 
| 227 | 
            -
                    search_path,
         | 
| 228 | 
            -
                    [dataset, "test", task],
         | 
| 229 | 
            -
                    exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
         | 
| 230 | 
            -
                    + filter_key_words,
         | 
| 231 | 
            -
                )
         | 
| 232 | 
            -
             | 
| 233 | 
            -
                for dat_type in ["test"]:
         | 
| 234 | 
            -
                    data = []
         | 
| 235 | 
            -
                    label_set = set()
         | 
| 236 | 
            -
                    for data_file in dataset_file[dat_type]:
         | 
| 237 | 
            -
                        with open(data_file, mode="r", encoding="utf8") as fin:
         | 
| 238 | 
            -
                            lines = fin.readlines()
         | 
| 239 | 
            -
                            for line in lines:
         | 
| 240 | 
            -
                                text, label = line.split("$LABEL$")
         | 
| 241 | 
            -
                                text = text.strip()
         | 
| 242 | 
            -
                                label = int(label.strip())
         | 
| 243 | 
            -
                                data.append((text, label))
         | 
| 244 | 
            -
                                label_set.add(label)
         | 
| 245 | 
            -
                    return data[random.randint(0, len(data))]
         | 
| 246 | 
            -
             | 
| 247 | 
            -
             | 
| 248 | 
            -
            def get_imdb_example():
         | 
| 249 | 
            -
                filter_key_words = [
         | 
| 250 | 
            -
                    ".py",
         | 
| 251 | 
            -
                    ".md",
         | 
| 252 | 
            -
                    "readme",
         | 
| 253 | 
            -
                    "log",
         | 
| 254 | 
            -
                    "result",
         | 
| 255 | 
            -
                    "zip",
         | 
| 256 | 
            -
                    ".state_dict",
         | 
| 257 | 
            -
                    ".model",
         | 
| 258 | 
            -
                    ".png",
         | 
| 259 | 
            -
                    "acc_",
         | 
| 260 | 
            -
                    "f1_",
         | 
| 261 | 
            -
                    ".origin",
         | 
| 262 | 
            -
                    ".adv",
         | 
| 263 | 
            -
                    ".csv",
         | 
| 264 | 
            -
                ]
         | 
| 265 | 
            -
             | 
| 266 | 
            -
                dataset_file = {"train": [], "test": [], "valid": []}
         | 
| 267 | 
            -
                dataset = "imdb"
         | 
| 268 | 
            -
                search_path = "./"
         | 
| 269 | 
            -
                task = "text_defense"
         | 
| 270 | 
            -
                dataset_file["test"] += find_files(
         | 
| 271 | 
            -
                    search_path,
         | 
| 272 | 
            -
                    [dataset, "test", task],
         | 
| 273 | 
            -
                    exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
         | 
| 274 | 
            -
                    + filter_key_words,
         | 
| 275 | 
            -
                )
         | 
| 276 |  | 
| 277 | 
            -
             | 
| 278 | 
            -
             | 
| 279 | 
            -
             | 
| 280 | 
            -
             | 
| 281 | 
            -
             | 
| 282 | 
            -
             | 
| 283 | 
            -
                            for line in lines:
         | 
| 284 | 
            -
                                text, label = line.split("$LABEL$")
         | 
| 285 | 
            -
                                text = text.strip()
         | 
| 286 | 
            -
                                label = int(label.strip())
         | 
| 287 | 
            -
                                data.append((text, label))
         | 
| 288 | 
            -
                                label_set.add(label)
         | 
| 289 | 
            -
                    return data[random.randint(0, len(data))]
         | 
| 290 |  | 
| 291 |  | 
| 292 | 
             
            cache = set()
         | 
| @@ -311,11 +89,11 @@ def generate_adversarial_example(dataset, attacker, text=None, label=None): | |
| 311 | 
             
                ].attacker.simple_attack(text, int(label))
         | 
| 312 | 
             
                if isinstance(attack_result, SuccessfulAttackResult):
         | 
| 313 | 
             
                    if (
         | 
| 314 | 
            -
             | 
| 315 | 
            -
             | 
| 316 | 
             
                    ) and (
         | 
| 317 | 
            -
             | 
| 318 | 
            -
             | 
| 319 | 
             
                    ):
         | 
| 320 | 
             
                        # with defense
         | 
| 321 | 
             
                        result = tad_classifiers["tad-{}".format(dataset.lower())].infer(
         | 
| @@ -367,133 +145,155 @@ def generate_adversarial_example(dataset, attacker, text=None, label=None): | |
| 367 | 
             
                )
         | 
| 368 |  | 
| 369 |  | 
| 370 | 
            -
             | 
| 371 | 
            -
             | 
| 372 | 
            -
                 | 
| 373 | 
            -
                     | 
| 374 | 
            -
             | 
| 375 | 
            -
             | 
| 376 | 
            -
             | 
| 377 | 
            -
             | 
| 378 | 
            -
                     | 
| 379 | 
            -
                     | 
| 380 | 
            -
             | 
| 381 | 
            -
             | 
| 382 | 
            -
                     | 
| 383 | 
            -
             | 
| 384 | 
            -
             | 
| 385 | 
            -
             | 
| 386 | 
            -
             | 
| 387 | 
            -
             | 
| 388 | 
            -
             | 
| 389 | 
            -
             | 
| 390 | 
            -
             | 
| 391 | 
            -
             | 
| 392 | 
            -
             | 
| 393 | 
            -
             | 
| 394 | 
            -
             | 
| 395 | 
            -
             | 
| 396 | 
            -
             | 
| 397 | 
            -
             | 
| 398 | 
            -
             | 
| 399 | 
            -
             | 
| 400 | 
            -
             | 
| 401 | 
            -
             | 
| 402 | 
            -
             | 
| 403 | 
            -
             | 
| 404 | 
            -
             | 
| 405 | 
            -
             | 
| 406 | 
            -
             | 
| 407 | 
            -
             | 
| 408 | 
            -
             | 
| 409 | 
            -
             | 
| 410 | 
            -
             | 
| 411 | 
            -
             | 
| 412 | 
            -
             | 
| 413 | 
            -
             | 
| 414 | 
            -
             | 
| 415 | 
            -
                        )
         | 
| 416 | 
             
                    with gr.Group():
         | 
| 417 | 
             
                        with gr.Row():
         | 
| 418 | 
            -
                             | 
| 419 | 
            -
                                 | 
| 420 | 
            -
                                 | 
|  | |
| 421 | 
             
                            )
         | 
| 422 | 
            -
                             | 
| 423 | 
            -
                                 | 
|  | |
|  | |
| 424 | 
             
                            )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 425 |  | 
| 426 | 
            -
             | 
| 427 | 
            -
                    "Generate an adversarial example and repair using RPD (No GPU, Time:3-10 mins )",
         | 
| 428 | 
            -
                    variant="primary",
         | 
| 429 | 
            -
                )
         | 
| 430 |  | 
| 431 | 
            -
                gr.Markdown(
         | 
| 432 | 
            -
                    "## <p align='center'>Generated Adversarial Example and Repaired Adversarial Example</p>"
         | 
| 433 | 
            -
                )
         | 
| 434 | 
            -
                with gr.Group():
         | 
| 435 | 
             
                    with gr.Column():
         | 
| 436 | 
            -
                        with gr. | 
| 437 | 
            -
                             | 
| 438 | 
            -
             | 
| 439 | 
            -
             | 
| 440 | 
            -
                             | 
| 441 | 
            -
             | 
| 442 | 
            -
             | 
| 443 | 
            -
                             | 
| 444 | 
            -
                                 | 
| 445 | 
            -
             | 
| 446 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 447 |  | 
| 448 | 
            -
                gr.Markdown(
         | 
| 449 | 
            -
                    "## <p align='center'>The Output of Reactive Perturbation Defocusing</p>"
         | 
| 450 | 
            -
                )
         | 
| 451 | 
            -
                with gr.Group():
         | 
| 452 | 
            -
                    output_is_adv_df = gr.DataFrame(label="Adversarial Example Detection Result")
         | 
| 453 | 
             
                    gr.Markdown(
         | 
| 454 | 
            -
                        " | 
| 455 | 
            -
                        "The perturbed_label is the predicted label of the adversarial example. "
         | 
| 456 | 
            -
                        "The confidence field represents the confidence of the predicted adversarial example detection. "
         | 
| 457 | 
             
                    )
         | 
| 458 | 
            -
                     | 
| 459 | 
            -
             | 
| 460 | 
            -
             | 
| 461 | 
            -
             | 
| 462 | 
            -
             | 
| 463 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 464 | 
             
                    )
         | 
| 465 |  | 
| 466 | 
            -
                 | 
| 467 | 
            -
                ori_text_diff = gr.HighlightedText(
         | 
| 468 | 
            -
                    label="The Original Natural Example",
         | 
| 469 | 
            -
                    combine_adjacent=True,
         | 
| 470 | 
            -
                )
         | 
| 471 | 
            -
                adv_text_diff = gr.HighlightedText(
         | 
| 472 | 
            -
                    label="Character Editions of Adversarial Example Compared to the Natural Example",
         | 
| 473 | 
            -
                    combine_adjacent=True,
         | 
| 474 | 
            -
                )
         | 
| 475 | 
            -
                restored_text_diff = gr.HighlightedText(
         | 
| 476 | 
            -
                    label="Character Editions of Repaired Adversarial Example Compared to the Natural Example",
         | 
| 477 | 
            -
                    combine_adjacent=True,
         | 
| 478 | 
            -
                )
         | 
| 479 | 
            -
             | 
| 480 | 
            -
                # Bind functions to buttons
         | 
| 481 | 
            -
                button_gen.click(
         | 
| 482 | 
            -
                    fn=generate_adversarial_example,
         | 
| 483 | 
            -
                    inputs=[input_dataset, input_attacker, input_sentence, input_label],
         | 
| 484 | 
            -
                    outputs=[
         | 
| 485 | 
            -
                        output_original_example,
         | 
| 486 | 
            -
                        output_original_label,
         | 
| 487 | 
            -
                        output_repaired_example,
         | 
| 488 | 
            -
                        output_repaired_label,
         | 
| 489 | 
            -
                        output_adv_example,
         | 
| 490 | 
            -
                        ori_text_diff,
         | 
| 491 | 
            -
                        adv_text_diff,
         | 
| 492 | 
            -
                        restored_text_diff,
         | 
| 493 | 
            -
                        output_adv_label,
         | 
| 494 | 
            -
                        output_df,
         | 
| 495 | 
            -
                        output_is_adv_df,
         | 
| 496 | 
            -
                    ],
         | 
| 497 | 
            -
                )
         | 
| 498 |  | 
| 499 | 
            -
            demo.launch()
         | 
|  | |
| 1 | 
             
            import os
         | 
|  | |
| 2 | 
             
            import zipfile
         | 
|  | |
| 3 |  | 
| 4 | 
             
            import gradio as gr
         | 
| 5 | 
             
            import nltk
         | 
| 6 | 
             
            import pandas as pd
         | 
| 7 | 
            +
            import requests
         | 
| 8 | 
            +
            from flask import Flask
         | 
| 9 |  | 
| 10 | 
             
            from anonymous_demo import TADCheckpointManager
         | 
|  | |
| 11 | 
             
            from textattack.attack_recipes import (
         | 
| 12 | 
             
                BAEGarg2019,
         | 
| 13 | 
             
                PWWSRen2019,
         | 
|  | |
| 19 | 
             
                CLARE2020,
         | 
| 20 | 
             
            )
         | 
| 21 | 
             
            from textattack.attack_results import SuccessfulAttackResult
         | 
| 22 | 
            +
            from utils import SentAttacker, get_agnews_example, get_sst2_example, get_amazon_example, get_imdb_example, diff_texts
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 23 |  | 
| 24 | 
             
            nltk.download("omw-1.4")
         | 
| 25 |  | 
|  | |
| 34 | 
             
                "iga": IGAWang2019,
         | 
| 35 | 
             
                "ga": GeneticAlgorithmAlzantot2018,
         | 
| 36 | 
             
                "deepwordbug": DeepWordBugGao2018,
         | 
| 37 | 
            +
                "clare": CLARE2020,
         | 
| 38 | 
             
            }
         | 
| 39 |  | 
| 40 | 
            +
            app = Flask(__name__)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            def init():
                """Unpack the bundled checkpoints and build every classifier/attacker pair.

                Extracts ``checkpoints.zip`` into the current working directory when the
                ``TAD-SST2`` directory is absent, then fills the module-level
                ``tad_classifiers`` and ``sent_attackers`` registries for each supported
                attacker and dataset.
                """
                if not os.path.exists("TAD-SST2"):
                    # The context manager closes the archive handle even when
                    # extraction fails part-way through.
                    with zipfile.ZipFile("checkpoints.zip", "r") as archive:
                        archive.extractall(os.getcwd())

                for attacker in ["pwws", "bae", "textfooler", "deepwordbug"]:
                    for dataset in [
                        "agnews10k",
                        "amazon",
                        "sst2",
                        # 'imdb'
                    ]:
                        classifier_key = "tad-{}".format(dataset)

                        # Each classifier is loaded once and shared by all attackers.
                        if classifier_key not in tad_classifiers:
                            tad_classifiers[
                                classifier_key
                            ] = TADCheckpointManager.get_tad_text_classifier(
                                classifier_key.upper()
                            )

                        sent_attackers["tad-{}{}".format(dataset, attacker)] = SentAttacker(
                            tad_classifiers[classifier_key], attack_recipes[attacker]
                        )
                        # The classifier always carries the PWWS attacker; "pwws" is the
                        # first attacker in the outer loop, so this key already exists.
                        tad_classifiers[classifier_key].sent_attacker = sent_attackers[
                            "tad-{}pwws".format(dataset)
                        ]
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 68 |  | 
| 69 |  | 
| 70 | 
             
            cache = set()
         | 
|  | |
| 89 | 
             
                ].attacker.simple_attack(text, int(label))
         | 
| 90 | 
             
                if isinstance(attack_result, SuccessfulAttackResult):
         | 
| 91 | 
             
                    if (
         | 
| 92 | 
            +
                            attack_result.perturbed_result.output
         | 
| 93 | 
            +
                            != attack_result.original_result.ground_truth_output
         | 
| 94 | 
             
                    ) and (
         | 
| 95 | 
            +
                            attack_result.original_result.output
         | 
| 96 | 
            +
                            == attack_result.original_result.ground_truth_output
         | 
| 97 | 
             
                    ):
         | 
| 98 | 
             
                        # with defense
         | 
| 99 | 
             
                        result = tad_classifiers["tad-{}".format(dataset.lower())].infer(
         | 
|  | |
| 145 | 
             
                )
         | 
| 146 |  | 
| 147 |  | 
| 148 | 
            +
            def run_demo(dataset, attacker, text=None, label=None):
                """Generate and repair an adversarial example, preferring the remote API.

                The request is first sent to the remote ``generate_adversarial_example``
                endpoint. If anything goes wrong (network failure, timeout, HTTP error,
                malformed or incomplete JSON), the error is printed and the work is done
                locally through ``generate_adversarial_example``.

                Args:
                    dataset: Name of the dataset selected in the UI.
                    attacker: Name of the attacker selected in the UI.
                    text: Optional natural example supplied by the user.
                    label: Optional original label of ``text``.

                Returns:
                    On the remote path, an 11-tuple: text, label, restored text, result
                    label, perturbed text, three diff entries, the output value, and two
                    ``pandas.DataFrame`` objects (classification and adversarial
                    detection). On the fallback path, whatever
                    ``generate_adversarial_example`` returns.
                """
                payload = {
                    "dataset": dataset,
                    "attacker": attacker,
                    "text": text,
                    "label": label,
                }
                try:
                    response = requests.post(
                        "https://rpddemo.pagekite.me/api/generate_adversarial_example",
                        json=payload,
                        # (connect, read) seconds. Without a timeout an unresponsive
                        # endpoint blocks this worker forever; the read limit is generous
                        # because generation may take several minutes.
                        timeout=(10, 600),
                    )
                    # Treat HTTP error statuses as failures so the fallback is used.
                    response.raise_for_status()
                    result = response.json()
                    print(result)
                    return (
                        result["text"],
                        result["label"],
                        result["restored_text"],
                        result["result_label"],
                        result["perturbed_text"],
                        result["text_diff"],
                        result["perturbed_diff"],
                        result["restored_diff"],
                        result["output"],
                        pd.DataFrame(result["classification_df"]),
                        pd.DataFrame(result["advdetection_df"]),
                    )
                except Exception as e:
                    # Deliberate best-effort: any remote failure falls back to local
                    # generation instead of surfacing an error in the UI.
                    print(e)
                    return generate_adversarial_example(dataset, attacker, text, label)
         | 
| 176 | 
            +
             | 
| 177 | 
            +
            if __name__ == "__main__":
                # Script entry point: load the models, build the Gradio UI, then serve it.

                init()

                demo = gr.Blocks()

                with demo:
                    # ---- Header and clarifications ----
                    gr.Markdown("<h1 align='center'>Reactive Perturbation Defocusing for Textual Adversarial Defense</h1>")
                    gr.Markdown("<h3 align='center'>Clarifications</h3>")
                    gr.Markdown("""
                - This demo has no mechanism to ensure the adversarial example will be correctly repaired by RPD. The repair success rate is actually the performance reported in the paper (approximately up to 97%).
                - The adversarial example and repaired adversarial example may be unnatural to read, while it is because the attackers usually generate unnatural perturbations. RPD does not introduce additional unnatural perturbations.
                - To our best knowledge, Reactive Perturbation Defocusing is a novel approach in adversarial defense. RPD significantly (>10% defense accuracy improvement) outperforms the state-of-the-art methods.
                - The DeepWordBug is an unknown attacker to the adversarial detector and reactive defense module. DeepWordBug has different attacking patterns from other attackers and shows the generalizability and robustness of RPD.
                """)

                    # ---- Inputs: dataset, attacker, optional custom example ----
                    gr.Markdown("<h2 align='center'>Natural Example Input</h2>")
                    with gr.Group():
                        with gr.Row():
                            input_dataset = gr.Radio(
                                choices=["SST2", "AGNews10K", "Amazon"],
                                value="Amazon",
                                label="Select a testing dataset and an adversarial attacker to generate an adversarial example.",
                            )
                            input_attacker = gr.Radio(
                                choices=["BAE", "PWWS", "TextFooler", "DeepWordBug"],
                                value="TextFooler",
                                label="Choose an Adversarial Attacker for generating an adversarial example to attack the model.",
                            )
                        with gr.Group():
                            with gr.Row():
                                input_sentence = gr.Textbox(
                                    placeholder="Input a natural example...",
                                    label="Alternatively, input a natural example and its original label to generate an adversarial example.",
                                )
                                input_label = gr.Textbox(
                                    placeholder="Original label...", label="Original Label"
                                )

                    button_gen = gr.Button(
                        "Generate an adversarial example to repair using RPD (it will take 1-10 minutes because no GPU is available)",
                        variant="primary",
                    )

                    # ---- Outputs: original, adversarial and repaired examples ----
                    gr.Markdown("<h2 align='center'>Generated Adversarial Example and Repaired Adversarial Example</h2>")

                    with gr.Column():
                        with gr.Group():
                            with gr.Row():
                                output_original_example = gr.Textbox(label="Original Example")
                                output_original_label = gr.Textbox(label="Original Label")
                            with gr.Row():
                                output_adv_example = gr.Textbox(label="Adversarial Example")
                                output_adv_label = gr.Textbox(label="Predicted Label of the Adversarial Example")
                            with gr.Row():
                                output_repaired_example = gr.Textbox(
                                    label="Repaired Adversarial Example by RPD"
                                )
                                output_repaired_label = gr.Textbox(label="Predicted Label of the Repaired Adversarial Example")

                    # ---- Character-level diffs against the natural example ----
                    gr.Markdown("<h2 align='center'>Example Difference (Comparisons)</h2>")
                    gr.Markdown("""
                <p align='center'>The (+) and (-) in the boxes indicate the added and deleted characters in the adversarial example compared to the original input natural example.</p>
                    """)
                    ori_text_diff = gr.HighlightedText(
                        label="The Original Natural Example",
                        combine_adjacent=True,
                    )
                    adv_text_diff = gr.HighlightedText(
                        label="Character Editions of Adversarial Example Compared to the Natural Example",
                        combine_adjacent=True,
                    )
                    restored_text_diff = gr.HighlightedText(
                        label="Character Editions of Repaired Adversarial Example Compared to the Natural Example",
                        combine_adjacent=True,
                    )

                    # ---- Detection and classification tables ----
                    gr.Markdown(
                        "## <h2 align='center'>The Output of Reactive Perturbation Defocusing</h2>"
                    )
                    with gr.Row():
                        with gr.Column():
                            with gr.Group():
                                output_is_adv_df = gr.DataFrame(
                                    label="Adversarial Example Detection Result"
                                )
                                gr.Markdown(
                                    "The is_adversarial field indicates if an adversarial example is detected. "
                                    "The perturbed_label is the predicted label of the adversarial example. "
                                    "The confidence field represents the confidence of the predicted adversarial example detection. "
                                )
                        with gr.Column():
                            with gr.Group():
                                output_df = gr.DataFrame(
                                    label="Repaired Standard Classification Result"
                                )
                                gr.Markdown(
                                    "If is_repaired=true, it has been repaired by RPD. "
                                    "The pred_label field indicates the standard classification result. "
                                    "The confidence field represents the confidence of the predicted label. "
                                    "The is_correct field indicates whether the predicted label is correct."
                                )

                    # Bind functions to buttons. The order of `outputs` must match the
                    # order of the tuple returned by run_demo.
                    button_gen.click(
                        fn=run_demo,
                        inputs=[input_dataset, input_attacker, input_sentence, input_label],
                        outputs=[
                            output_original_example,
                            output_original_label,
                            output_repaired_example,
                            output_repaired_label,
                            output_adv_example,
                            ori_text_diff,
                            adv_text_diff,
                            restored_text_diff,
                            output_adv_label,
                            output_df,
                            output_is_adv_df,
                        ],
                    )

                demo.queue(2).launch()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 299 |  | 
|  | 
    	
        requirements.txt
    CHANGED
    
    | @@ -20,4 +20,5 @@ textattack[dev] | |
| 20 | 
             
            jieba
         | 
| 21 | 
             
            pycld2
         | 
| 22 | 
             
            OpenHowNet
         | 
| 23 | 
            -
            pinyin
         | 
|  | 
|  | |
| 20 | 
             
            jieba
         | 
| 21 | 
             
            pycld2
         | 
| 22 | 
             
            OpenHowNet
         | 
| 23 | 
            +
            pinyin
         | 
| 24 | 
            +
            flask
         | 
    	
        utils.py
    ADDED
    
    | @@ -0,0 +1,234 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import random
         | 
| 2 | 
            +
            from difflib import Differ
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from textattack.attack_recipes import BAEGarg2019
         | 
| 5 | 
            +
            from textattack.datasets import Dataset
         | 
| 6 | 
            +
            from textattack.models.wrappers import HuggingFaceModelWrapper
         | 
| 7 | 
            +
            from findfile import find_files
         | 
| 8 | 
            +
            from flask import Flask
         | 
| 9 | 
            +
            from textattack import Attacker
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class ModelWrapper(HuggingFaceModelWrapper):
                """Adapter exposing an object with an ``infer`` method as a TextAttack model wrapper."""

                def __init__(self, model):
                    # Only the wrapped model is stored; the parent initializer is
                    # intentionally not invoked here.
                    self.model = model

                def __call__(self, text_inputs, **kwargs):
                    """Return the ``"probs"`` entry of ``model.infer`` for every input text.

                    Extra keyword arguments are forwarded to ``infer`` unchanged, and
                    result printing is always disabled.
                    """
                    return [
                        self.model.infer(sample, print_result=False, **kwargs)["probs"]
                        for sample in text_inputs
                    ]
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            class SentAttacker:
                """Holds a TextAttack ``Attacker`` built from a model and an attack recipe."""

                def __init__(self, model, recipe_class=BAEGarg2019):
                    """Build the attacker and store it on ``self.attacker``.

                    Args:
                        model: Object passed to ``ModelWrapper``.
                        recipe_class: Attack recipe class whose ``build`` classmethod
                            produces the attack; defaults to ``BAEGarg2019``.
                    """
                    wrapped_model = ModelWrapper(model)
                    attack = recipe_class.build(wrapped_model)
                    # The recipe's transformation language is left at its default.

                    # Attacker is constructed with a one-item placeholder dataset.
                    placeholder = Dataset([("", 0)])

                    self.attacker = Attacker(attack, placeholder)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
| 40 | 
            +
            def diff_texts(text1, text2):
                """Compare two sequences with ``difflib.Differ`` and return tagged pieces.

                Each element of the result is ``(piece, tag)`` where ``piece`` is the
                Differ output line without its two-character prefix and ``tag`` is the
                first prefix character, or ``None`` when that character is a space
                (i.e. the piece is unchanged).
                """
                tagged = []
                for entry in Differ().compare(text1, text2):
                    marker = entry[0]
                    tagged.append((entry[2:], None if marker == " " else marker))
                return tagged
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            def get_ensembled_tad_results(results):
                """Return the most frequent ``"label"`` value among ``results``.

                When several labels share the highest count, the one that first
                appeared latest in ``results`` wins. An empty ``results`` raises
                ``ValueError`` from ``max``.
                """
                votes = {}
                for item in results:
                    lbl = item["label"]
                    votes[lbl] = votes.get(lbl, 0) + 1

                top_count = max(votes.values())

                winner = None
                for lbl, count in votes.items():
                    if count == top_count:
                        # Keep overwriting so the last tied label is returned.
                        winner = lbl
                return winner
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            def get_sst2_example():
                """Return one random ``(text, label)`` pair from the SST2 test data.

                Files are located with ``find_files`` under ``./`` using the keywords
                ``"sst2"``, ``"test"`` and ``"text_defense"`` (presumably matched against
                the file path -- confirm against the findfile documentation). Each
                non-blank line is expected to look like ``<text>$LABEL$<int label>``.

                Returns:
                    tuple: ``(text, label)`` with ``text`` stripped and ``label`` an int.

                Raises:
                    IndexError: If no example could be loaded.
                    ValueError: If a line does not contain exactly one ``$LABEL$``
                        separator or its label is not an integer.
                """
                filter_key_words = [
                    ".py",
                    ".md",
                    "readme",
                    "log",
                    "result",
                    "zip",
                    ".state_dict",
                    ".model",
                    ".png",
                    "acc_",
                    "f1_",
                    ".origin",
                    ".adv",
                    ".csv",
                ]

                test_files = find_files(
                    "./",
                    ["sst2", "test", "text_defense"],
                    exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
                                + filter_key_words,
                )

                data = []
                for data_file in test_files:
                    with open(data_file, mode="r", encoding="utf8") as fin:
                        for line in fin:
                            if not line.strip():
                                # Tolerate blank lines such as a trailing newline.
                                continue
                            text, label = line.split("$LABEL$")
                            data.append((text.strip(), int(label.strip())))

                # random.randint(0, len(data)) is inclusive at both ends and could index
                # one past the end of the list; random.choice always stays in range.
                return random.choice(data)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
             | 
| 105 | 
            +
            def get_agnews_example():
                """Return one random ``(text, label)`` pair from the AGNews test data.

                Files are located with ``find_files`` under ``./`` using the keywords
                ``"agnews"``, ``"test"`` and ``"text_defense"`` (presumably matched
                against the file path -- confirm against the findfile documentation).
                Each non-blank line is expected to look like ``<text>$LABEL$<int label>``.

                Returns:
                    tuple: ``(text, label)`` with ``text`` stripped and ``label`` an int.

                Raises:
                    IndexError: If no example could be loaded.
                    ValueError: If a line does not contain exactly one ``$LABEL$``
                        separator or its label is not an integer.
                """
                filter_key_words = [
                    ".py",
                    ".md",
                    "readme",
                    "log",
                    "result",
                    "zip",
                    ".state_dict",
                    ".model",
                    ".png",
                    "acc_",
                    "f1_",
                    ".origin",
                    ".adv",
                    ".csv",
                ]

                test_files = find_files(
                    "./",
                    ["agnews", "test", "text_defense"],
                    exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
                                + filter_key_words,
                )

                data = []
                for data_file in test_files:
                    with open(data_file, mode="r", encoding="utf8") as fin:
                        for line in fin:
                            if not line.strip():
                                # Tolerate blank lines such as a trailing newline.
                                continue
                            text, label = line.split("$LABEL$")
                            data.append((text.strip(), int(label.strip())))

                # random.randint(0, len(data)) is inclusive at both ends and could index
                # one past the end of the list; random.choice always stays in range.
                return random.choice(data)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
             | 
| 148 | 
            +
            def get_amazon_example():
                """Return one randomly chosen ``(text, label)`` pair from the Amazon test data.

                Candidate files are looked up with ``find_files`` starting at ``"./"``.
                Every non-blank line of each file is expected to have the form
                ``<text>$LABEL$<integer label>``.

                Returns:
                    tuple: ``(text, label)`` where ``text`` is the stripped sentence and
                    ``label`` is an ``int``.

                Raises:
                    IndexError: if no example could be loaded (no file found or all
                        files empty).
                    ValueError: if a non-blank line does not split into exactly two
                        parts on ``$LABEL$`` or its label is not an integer.
                """
                # Substrings passed to find_files as exclusions, in addition to the
                # fixed list below.
                filter_key_words = [
                    ".py",
                    ".md",
                    "readme",
                    "log",
                    "result",
                    "zip",
                    ".state_dict",
                    ".model",
                    ".png",
                    "acc_",
                    "f1_",
                    ".origin",
                    ".adv",
                    ".csv",
                ]

                dataset = "amazon"
                search_path = "./"
                task = "text_defense"
                # NOTE(review): find_files is presumably the `findfile` helper that
                # returns paths matching every keyword and none of `exclude_key` --
                # confirm against its documentation.
                test_files = find_files(
                    search_path,
                    [dataset, "test", task],
                    exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
                    + filter_key_words,
                )

                data = []
                for data_file in test_files:
                    with open(data_file, mode="r", encoding="utf8") as fin:
                        # Stream the file instead of materialising it with readlines().
                        for line in fin:
                            if not line.strip():
                                # Tolerate blank/trailing lines rather than crashing on
                                # the tuple unpack below.
                                continue
                            text, label = line.split("$LABEL$")
                            data.append((text.strip(), int(label.strip())))

                # random.randint(0, len(data)) is inclusive on both ends and could index
                # one past the end; random.choice always picks a valid element.
                return random.choice(data)
         | 
| 190 | 
            +
             | 
| 191 | 
            +
             | 
| 192 | 
            +
            def get_imdb_example():
                """Return one randomly chosen ``(text, label)`` pair from the IMDB test data.

                Candidate files are looked up with ``find_files`` starting at ``"./"``.
                Every non-blank line of each file is expected to have the form
                ``<text>$LABEL$<integer label>``.

                Returns:
                    tuple: ``(text, label)`` where ``text`` is the stripped sentence and
                    ``label`` is an ``int``.

                Raises:
                    IndexError: if no example could be loaded (no file found or all
                        files empty).
                    ValueError: if a non-blank line does not split into exactly two
                        parts on ``$LABEL$`` or its label is not an integer.
                """
                # Substrings passed to find_files as exclusions, in addition to the
                # fixed list below.
                filter_key_words = [
                    ".py",
                    ".md",
                    "readme",
                    "log",
                    "result",
                    "zip",
                    ".state_dict",
                    ".model",
                    ".png",
                    "acc_",
                    "f1_",
                    ".origin",
                    ".adv",
                    ".csv",
                ]

                dataset = "imdb"
                search_path = "./"
                task = "text_defense"
                # NOTE(review): find_files is presumably the `findfile` helper that
                # returns paths matching every keyword and none of `exclude_key` --
                # confirm against its documentation.
                test_files = find_files(
                    search_path,
                    [dataset, "test", task],
                    exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
                    + filter_key_words,
                )

                data = []
                for data_file in test_files:
                    with open(data_file, mode="r", encoding="utf8") as fin:
                        # Stream the file instead of materialising it with readlines().
                        for line in fin:
                            if not line.strip():
                                # Tolerate blank/trailing lines rather than crashing on
                                # the tuple unpack below.
                                continue
                            text, label = line.split("$LABEL$")
                            data.append((text.strip(), int(label.strip())))

                # random.randint(0, len(data)) is inclusive on both ends and could index
                # one past the end; random.choice always picks a valid element.
                return random.choice(data)
         | 
| 234 | 
            +
             |