File size: 3,481 Bytes
999ff86
ebfeea8
3be5501
7ac9af3
3be5501
999ff86
7ac9af3
999ff86
3be5501
7ac9af3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3be5501
be8e821
3be5501
 
 
 
7ac9af3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from fawkes.protection import Fawkes
from fawkes.utils import Faces, reverse_process_cloaked
from fawkes.differentiator import FawkesMaskGeneration
from keras import utils
import numpy as np
import gradio as gr
import spaces

# Side length of the square face crops given to FawkesMaskGeneration through
# image_shape=(IMG_SIZE, IMG_SIZE, 3).
IMG_SIZE = 112
# Intensity/preprocessing mode used both for FawkesMaskGeneration's
# intensity_range and for reverse_process_cloaked's preprocess argument;
# keep the two in sync by changing only this constant.
PREPROCESS = "raw"


def get_extractors():
    """Fetch the Fawkes feature-extractor weights into the Keras file cache.

    For each extractor, ``<name>.h5`` is requested from the UChicago Fawkes
    mirror with ``keras.utils.get_file``, verified against the listed MD5
    digest, and stored under the ``model`` cache subdirectory.
    Returns ``None``.
    """
    # NOTE(review): the mirror URL is plain HTTP, so the MD5 digest is the only
    # integrity check on the downloaded weights.
    extractors = (
        ("extractor_2", "ce703d481db2b83513bbdafa27434703"),
        ("extractor_0", "94854151fd9077997d69ceda107f9c6b"),
    )
    for name, digest in extractors:
        utils.get_file(
            fname=f"{name}.h5",
            origin=f"http://mirror.cs.uchicago.edu/fawkes/files/{name}.h5",
            md5_hash=digest,
            cache_subdir="model",
        )


def generate_cloak_images(protector, image_X, target_emb=None):
    """Return the result of ``protector.compute`` for a batch of face crops.

    Args:
        protector: Object exposing ``compute(images, target_emb)``, such as
            the ``FawkesMaskGeneration`` instance that ``predict`` builds.
        image_X: Images forwarded unchanged to ``compute``.
        target_emb: Optional target embedding, forwarded unchanged.
    """
    return protector.compute(image_X, target_emb)


def preproc(img):
    """Convert ``img`` to RGB and return it as an array.

    ``img.convert("RGB")`` normalises the colour mode (e.g. RGBA or greyscale
    PIL images become 3-channel) before ``keras.utils.img_to_array`` turns it
    into an array.
    """
    rgb_image = img.convert("RGB")
    return utils.img_to_array(rgb_image)


@spaces.GPU
def predict(
    img,
    level,
    sd=1e7,
    format="png",
    separate_target=True,
    debug=False,
    maximize=True,
    save_last_on_failed=True,
    progress=gr.Progress(track_tqdm=True),
):
    """Cloak the face in ``img`` with Fawkes and return the protected image.

    Args:
        img: Image from the ``Image(type="pil")`` input component; may be
            ``None`` when the user submits without uploading.
        level: Protection level, one of ``"low"``, ``"mid"`` or ``"high"``,
            passed to ``Fawkes`` as its ``mode``; may be ``None`` when no
            radio option is selected.
        sd: ``initial_const`` for ``FawkesMaskGeneration``.
        format, separate_target, debug: Only folded into the protector
            parameter key; not otherwise used in this function.
        maximize, save_last_on_failed: Forwarded to ``FawkesMaskGeneration``.
        progress: Gradio progress tracker (``track_tqdm=True`` lets Gradio
            follow tqdm progress bars); not referenced directly here.

    Returns:
        The last image returned by ``Faces.merge_faces``, cast to ``uint8``.

    Raises:
        gr.Error: If no image is given, ``level`` is not a known protection
            level, or no face is detected. ``gr.Error`` subclasses
            ``Exception``, so callers catching ``Exception`` still work.
    """
    # Validate inputs up front: without these checks an unknown/missing level
    # leaves ``fwks`` unbound (UnboundLocalError) and a missing image fails
    # deep inside ``preproc`` with an opaque AttributeError.
    if img is None:
        raise gr.Error("Please upload an image.")
    if level not in ("low", "mid", "high"):
        raise gr.Error(
            "Invalid protection level {!r}; choose low, mid or high.".format(level)
        )

    img = preproc(img)

    # Every level uses the same extractor; only the Fawkes mode differs.
    fwks = Fawkes("extractor_2", 1, mode=level)

    # Key describing the protector configuration; a new FawkesMaskGeneration
    # is built only when it differs from ``fwks.protector_param``.
    # NOTE(review): the -1 slot presumably stands in for the batch size that
    # upstream Fawkes puts in this key — confirm.
    current_param = "-".join(
        [
            str(x)
            for x in [
                fwks.th,
                sd,
                fwks.lr,
                fwks.max_step,
                -1,
                format,
                separate_target,
                debug,
            ]
        ]
    )
    faces = Faces(["./Current Face"], [img], fwks.aligner, verbose=0, no_align=False)
    original_images = faces.cropped_faces

    if len(original_images) == 0:
        raise gr.Error("No face detected. ")
    original_images = np.array(original_images)

    if current_param != fwks.protector_param:
        fwks.protector_param = current_param
        if fwks.protector is not None:
            # Release the previous protector before building its replacement.
            del fwks.protector
        fwks.protector = FawkesMaskGeneration(
            fwks.feature_extractors_ls,
            batch_size=len(original_images),
            mimic_img=True,
            intensity_range=PREPROCESS,
            initial_const=sd,
            learning_rate=fwks.lr,
            max_iterations=fwks.max_step,
            l_threshold=fwks.th,
            verbose=0,
            maximize=maximize,
            keep_final=False,
            image_shape=(IMG_SIZE, IMG_SIZE, 3),
            loss_method="features",
            tanh_process=True,
            save_last_on_failed=save_last_on_failed,
        )
    protected_images = generate_cloak_images(fwks.protector, original_images)
    faces.cloaked_cropped_faces = protected_images

    # Paste the cloaked crops back into the full image; only one input image
    # was given to Faces, so the last merged result is the one to return.
    final_images, _ = faces.merge_faces(
        reverse_process_cloaked(protected_images, preprocess=PREPROCESS),
        reverse_process_cloaked(original_images, preprocess=PREPROCESS),
    )

    return final_images[-1].astype(np.uint8)


# Fetch the extractor weights at start-up, before the UI begins serving.
get_extractors()

# Single-page UI: a face photo plus a protection level in, the cloaked image out.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Radio(["low", "mid", "high"], label="Protection Level"),
    ],
    outputs=gr.Image(type="pil"),
    allow_flagging="never",
)
demo.launch(show_error=True, quiet=False)