feat: support meta SSL watermarking
Files changed:
- SSL_watermark.py   +87 -0
- app.py             +23 -6
- dino_r50.pth        +3 -0
- image_utils.py     +80 -0
- out2048.pth         +3 -0
- requirements.txt    +4 -0
- torch_utils.py     +84 -0
SSL_watermark.py
ADDED
@@ -0,0 +1,87 @@
import numpy as np
import json

import torch
from torchvision import transforms

import torch_utils
import image_utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)
np.random.seed(0)

print('Building backbone and normalization layer...')
backbone = torch_utils.build_backbone(path='dino_r50.pth')
normlayer = torch_utils.load_normalization_layer(path='out2048.pth')
model = torch_utils.NormLayerWrapper(backbone, normlayer)
model = model.to(device)  # images are sent to `device` below, so the model must live there too

print('Building the hypercone...')
FPR = 1e-6
angle = 1.462771101178447  # value for FPR=1e-6 and D=2048, see torch_utils.pvalue_angle
rho = 1 + np.tan(angle)**2
carrier = torch.randn(1, 2048)
carrier /= torch.norm(carrier, dim=1, keepdim=True)
carrier = carrier.to(device)  # features come out of the model on `device`

default_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
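The constants above encode the hypercone detector in two equivalent forms: a feature f lies inside the cone of half-angle `angle` around the unit carrier c exactly when rho * <f,c>**2 - ||f||**2 > 0, which is the same test as |cos(f,c)| > cos(angle), since 1/rho == cos(angle)**2. A standalone NumPy sketch checking that equivalence (the 2048 mirrors the feature dimension D used above):

import numpy as np

angle = 1.462771101178447                 # hypercone half-angle for FPR=1e-6, D=2048
rho = 1 + np.tan(angle)**2                # note: 1/rho == np.cos(angle)**2

ft = np.random.randn(2048)                # a stand-in feature vector
c = np.random.randn(2048)
c /= np.linalg.norm(c)                    # a random unit carrier

inside_cone = rho * np.dot(ft, c)**2 - np.dot(ft, ft) > 0
above_cosine = abs(np.dot(ft, c)) / np.linalg.norm(ft) > np.cos(angle)
assert inside_cone == above_cosine        # the two detection rules agree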
def encode(image, epochs=10, psnr=44, lambda_w=1, lambda_i=1):
    img_orig = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    img = img_orig.clone().to(device, non_blocking=True)
    img.requires_grad = True
    optimizer = torch.optim.Adam([img], lr=1e-2)

    for iteration in range(epochs):
        print(f'iteration: {iteration}')
        x = image_utils.ssim_attenuation(img, img_orig)
        x = image_utils.psnr_clip(x, img_orig, psnr)

        ft = model(x)  # BxCxWxH -> BxD

        dot_product = ft @ carrier.T  # BxD @ Dx1 -> Bx1
        norm = torch.norm(ft, dim=-1, keepdim=True)  # Bx1
        cosines = torch.abs(dot_product / norm)
        log10_pvalue = np.log10(torch_utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
        loss_R = -(rho * dot_product**2 - norm**2)  # negative when the feature is inside the hypercone

        loss_l2_img = torch.norm(x - img_orig)**2  # CxWxH -> 1
        loss = lambda_w * loss_R + lambda_i * loss_l2_img

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logs = {
            "keyword": "img_optim",
            "iteration": iteration,
            "loss": loss.item(),
            "loss_R": loss_R.item(),
            "loss_l2_img": loss_l2_img.item(),
            "log10_pvalue": log10_pvalue.item(),
        }
        print("__log__:%s" % json.dumps(logs))

    img = image_utils.ssim_attenuation(img, img_orig)
    img = image_utils.psnr_clip(img, img_orig, psnr)
    img = image_utils.round_pixel(img)
    img = img.squeeze(0).detach().cpu()
    # unnormalize_img moves the tensor back to `device`; ToPILImage needs it on the CPU
    img = transforms.ToPILImage()(image_utils.unnormalize_img(img).squeeze(0).cpu())

    return img

def decode(image):
    img = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    ft = model(img)  # BxCxWxH -> BxD

    dot_product = ft @ carrier.T  # BxD @ Dx1 -> Bx1
    norm = torch.norm(ft, dim=-1, keepdim=True)  # Bx1
    cosines = torch.abs(dot_product / norm)
    log10_pvalue = np.log10(torch_utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
    loss_R = -(rho * dot_product**2 - norm**2)  # negative when inside the hypercone

    text_marked = "marked" if loss_R < 0 else "unmarked"
    return f'Image is {text_marked}, with p-value={10**log10_pvalue}'
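For reference, a minimal usage sketch for this module. Importing it builds the model, so the dino_r50.pth and out2048.pth checkpoints from this commit must sit in the working directory; example.jpg is a placeholder path:

from PIL import Image

from SSL_watermark import encode, decode

image = Image.open('example.jpg').convert('RGB')
marked = encode(image, epochs=10, psnr=44)   # walk the features into the hypercone
marked.save('example_marked.png')
print(decode(marked))                        # "Image is marked, with p-value=..."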
app.py
CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
 from steganography import Steganography
 from utils import draw_multiple_line_text, generate_qr_code
+from SSL_watermark import encode, decode
 
 
 TITLE = """<h2 align="center"> ✍️ Invisible Watermark </h2>"""
@@ -8,20 +9,27 @@
 
 def apply_watermark(radio_button, input_image, watermark_image, watermark_text, watermark_url):
     input_image = input_image.convert('RGB')
-
+    print(f'radio_button: {radio_button}')
     if radio_button == "Image":
         watermark_image = watermark_image.resize((input_image.width, input_image.height)).convert('L').convert('RGB')
         return Steganography().merge(input_image, watermark_image, digit=7)
     elif radio_button == "Text":
         watermark_image = draw_multiple_line_text(input_image.size, watermark_text)
         return Steganography().merge(input_image, watermark_image, digit=7)
-
+    elif radio_button == "QRCode":
         size = min(input_image.width, input_image.height)
         watermark_image = generate_qr_code(watermark_url).resize((size, size)).convert('RGB')
         return Steganography().merge(input_image, watermark_image, digit=7)
+    else:
+        print('start encoding ssl watermark...')
+        return encode(input_image, epochs=5)
 
-def extract_watermark(input_image_to_extract):
-
+def extract_watermark(extract_radio_button, input_image_to_extract):
+    if extract_radio_button == 'Steganography':
+        return Steganography().unmerge(input_image_to_extract.convert('RGB'), digit=7).convert('RGBA')
+    else:
+        decoded_info = decode(image=input_image_to_extract)
+        return draw_multiple_line_text(input_image_size=input_image_to_extract.size, text=decoded_info)
 
 
 with gr.Blocks() as demo:
@@ -34,7 +42,7 @@
     with gr.Blocks():
         gr.Markdown("### Which type of watermark you want to apply?")
         radio_button = gr.Radio(
-            choices=["QRCode", "Text", "Image"],
+            choices=["QRCode", "Text", "Image", "SSL Watermark"],
             label="Watermark type",
             value="QRCode",
             # info="Which type of watermark you want to apply?"
@@ -82,6 +90,11 @@
         with gr.Column():
             gr.Markdown("### Image to extract watermark")
             input_image_to_extract = gr.Image(type='pil')
+            extract_radio_button = gr.Radio(
+                choices=["Steganography", "SSL Watermark"],
+                label="Extract methods",
+                value="Steganography"
+            )
         with gr.Column():
             gr.Markdown("### Extracted watermark")
             extracted_watermark = gr.Image(type='pil')
@@ -97,6 +110,10 @@
         inputs=[radio_button, input_image, watermark_image, watermark_text, watermark_url],
         outputs=[output_image]
     )
-    extract_button.click(
+    extract_button.click(
+        fn=extract_watermark,
+        inputs=[extract_radio_button, input_image_to_extract],
+        outputs=[extracted_watermark]
+    )
 
 demo.launch()
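The wiring added here follows one Gradio pattern: a gr.Radio component is passed as the first input of a click handler, and the handler branches on its value. The same pattern in isolation, as a toy sketch (names and behavior are illustrative only, unrelated to watermarking):

import gradio as gr

def route(method, text):
    # The radio value arrives as the first argument, just like
    # extract_watermark(extract_radio_button, input_image_to_extract) above.
    if method == "Upper":
        return text.upper()
    return text[::-1]

with gr.Blocks() as demo:
    method = gr.Radio(choices=["Upper", "Reverse"], label="Method", value="Upper")
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    btn = gr.Button("Run")
    btn.click(fn=route, inputs=[method, inp], outputs=[out])

demo.launch()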
dino_r50.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ab26d85d00cb1be8e757cf8820cf0fd8aa729ea7e21b1cf6c44875952ba8eb0f
size 788803344
image_utils.py
ADDED
@@ -0,0 +1,80 @@
import numpy as np

import torch
from torchvision import transforms

import torch.nn.functional as F

from torch.autograd.variable import Variable

NORMALIZE_IMAGENET = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_mean = torch.Tensor(NORMALIZE_IMAGENET.mean).view(-1, 1, 1).to(device)
image_std = torch.Tensor(NORMALIZE_IMAGENET.std).view(-1, 1, 1).to(device)

def normalize_img(x):
    return (x.to(device) - image_mean) / image_std

def unnormalize_img(x):
    return (x.to(device) * image_std) + image_mean

def round_pixel(x):
    """ Quantize x to integer pixel values in [0, 255] """
    x_pixel = 255 * unnormalize_img(x)
    y = torch.round(x_pixel).clamp(0, 255)
    y = normalize_img(y / 255.0)
    return y

def project_linf(x, y, radius):
    """ Clamp x-y so that Linf(x,y) <= radius """
    delta = x - y
    delta = 255 * (delta * image_std)  # back to 8-bit pixel units
    delta = torch.clamp(delta, -radius, radius)
    delta = (delta / 255.0) / image_std
    return y + delta

def psnr_clip(x, y, target_psnr):
    """ Scale x-y down so that PSNR(x,y) >= target_psnr """
    delta = x - y
    delta = 255 * (delta * image_std)
    psnr = 20 * np.log10(255) - 10 * torch.log10(torch.mean(delta**2))
    if psnr < target_psnr:
        delta = torch.sqrt(10**((psnr - target_psnr) / 10)) * delta
        psnr = 20 * np.log10(255) - 10 * torch.log10(torch.mean(delta**2))
    delta = (delta / 255.0) / image_std
    return y + delta

def ssim_heatmap(img1, img2, window_size):
    """ Compute the SSIM heatmap between 2 images """
    _1D_window = torch.Tensor(
        [np.exp(-(x - window_size//2)**2 / float(2 * 1.5**2)) for x in range(window_size)]
    ).to(device, non_blocking=True)
    _1D_window = (_1D_window / _1D_window.sum()).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(3, 1, window_size, window_size).contiguous())

    mu1 = F.conv2d(img1, window, padding=window_size//2, groups=3)
    mu2 = F.conv2d(img2, window, padding=window_size//2, groups=3)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=3) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=3) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=3) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1) * (2*sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map

def ssim_attenuation(x, y):
    """ Attenuate x-y using the SSIM heatmap """
    delta = x - y
    ssim_map = ssim_heatmap(x, y, window_size=17)  # 1xCxHxW
    ssim_map = torch.sum(ssim_map, dim=1, keepdim=True)
    ssim_map = torch.clamp_min(ssim_map, 0)
    delta = delta * ssim_map
    return y + delta
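The rescaling in psnr_clip is exact rather than iterative: with PSNR = 20*log10(255) - 10*log10(mean(delta**2)), multiplying delta by sqrt(10**((psnr - target)/10)) shifts the PSNR up by precisely the deficit. A self-contained NumPy check:

import numpy as np

def psnr(delta):
    # PSNR in dB for a perturbation delta expressed in 8-bit pixel units
    return 20 * np.log10(255) - 10 * np.log10(np.mean(delta**2))

delta = 10.0 * np.random.randn(1000)      # strong perturbation, PSNR well below 44 dB
p = psnr(delta)
scale = np.sqrt(10 ** ((p - 44) / 10))    # the same factor psnr_clip applies
print(psnr(scale * delta))                # prints 44.0 up to floating-point error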
out2048.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4b256188454d8f7cf440de048df398e2a3209136a52cd7cdac834f5792f526a3
size 16786561
requirements.txt
CHANGED
@@ -1,4 +1,8 @@
+torch==1.10.1
+torchvision==0.11.2
 Pillow
 click
 gradio
 qrcode
+scipy
+json
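One catch in the new list: json is part of the Python standard library, not a PyPI package, so pip install -r requirements.txt will fail on that line. The module is always available, so a corrected requirements.txt would simply drop it:

torch==1.10.1
torchvision==0.11.2
Pillow
click
gradio
qrcode
scipy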
torch_utils.py
ADDED
@@ -0,0 +1,84 @@
import numpy as np

import torch
import torch.nn as nn
from torchvision import models

from scipy.optimize import root_scalar
from scipy.special import betainc

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def build_backbone(path, name='resnet50'):
    """ Builds a backbone (ResNet-50 by default) and loads its weights from a checkpoint. """
    model = getattr(models, name)(pretrained=False)
    model.head = nn.Identity()
    model.fc = nn.Identity()
    checkpoint = torch.load(path, map_location=device)
    state_dict = checkpoint
    for ckpt_key in ['state_dict', 'model_state_dict', 'teacher']:
        if ckpt_key in checkpoint:
            state_dict = checkpoint[ckpt_key]
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    msg = model.load_state_dict(state_dict, strict=False)
    return model

def get_linear_layer(weight, bias):
    """ Creates a layer that performs feature whitening or centering """
    dim_out, dim_in = weight.shape
    layer = nn.Linear(dim_in, dim_out)
    layer.weight = nn.Parameter(weight)
    layer.bias = nn.Parameter(bias)
    return layer

def load_normalization_layer(path):
    """
    Loads the normalization layer from a checkpoint and returns the layer.
    """
    checkpoint = torch.load(path, map_location=device)
    if 'whitening' in path or 'out' in path:
        D = checkpoint['weight'].shape[1]
        weight = torch.nn.Parameter(D * checkpoint['weight'])
        bias = torch.nn.Parameter(D * checkpoint['bias'])
    else:
        weight = checkpoint['weight']
        bias = checkpoint['bias']
    return get_linear_layer(weight, bias).to(device, non_blocking=True)

class NormLayerWrapper(nn.Module):
    """
    Wraps backbone model and normalization layer
    """
    def __init__(self, backbone, head):
        super(NormLayerWrapper, self).__init__()
        backbone.eval(), head.eval()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        output = self.backbone(x)
        return self.head(output)

def cosine_pvalue(c, d, k=1):
    """
    Returns the probability that the absolute value of the cosine
    between a random unit vector and a fixed one is higher than c.
    Args:
        c: cosine value
        d: dimension of the features
        k: number of dimensions of the projection
    """
    assert k > 0
    a = (d - k) / 2.0
    b = k / 2.0
    if c < 0:
        return 1.0
    return betainc(a, b, 1 - c ** 2)

def pvalue_angle(dim, k=1, angle=None, proba=None):
    """ Solves for the angle whose cosine p-value equals proba. """
    def f(a):
        return cosine_pvalue(np.cos(a), dim, k) - proba
    a = root_scalar(f, x0=0.49 * np.pi, bracket=[0, np.pi / 2])
    # a = fsolve(f, x0=0.49*np.pi)[0]
    return a.root
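As a closing sanity check, pvalue_angle and cosine_pvalue invert each other, which is how the constant hard-coded in SSL_watermark.py can be re-derived; a sketch assuming this module is importable:

import numpy as np
from torch_utils import cosine_pvalue, pvalue_angle

angle = pvalue_angle(dim=2048, k=1, proba=1e-6)
print(angle)                                 # ~1.462771101178447, the hard-coded value
print(cosine_pvalue(np.cos(angle), d=2048))  # ~1e-6, the target false positive rate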