AnsenH committed
Commit a68d5f1 · Parent: 41258b3

feat: support meta SSL watermarking

Files changed (7)
  1. SSL_watermark.py +87 -0
  2. app.py +23 -6
  3. dino_r50.pth +3 -0
  4. image_utils.py +80 -0
  5. out2048.pth +3 -0
  6. requirements.txt +3 -0
  7. torch_utils.py +84 -0
SSL_watermark.py ADDED
@@ -0,0 +1,87 @@
+ import numpy as np
+ import json
+
+ import torch
+ from torchvision import transforms
+
+ import torch_utils
+ import image_utils
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ torch.manual_seed(0)
+ np.random.seed(0)
+
+ print('Building backbone and normalization layer...')
+ backbone = torch_utils.build_backbone(path='dino_r50.pth')
+ normlayer = torch_utils.load_normalization_layer(path='out2048.pth')
+ model = torch_utils.NormLayerWrapper(backbone, normlayer).to(device)
+
+ print('Building the hypercone...')
+ FPR = 1e-6
+ angle = 1.462771101178447  # value for FPR=1e-6 and D=2048
+ rho = 1 + np.tan(angle)**2
+ carrier = torch.randn(1, 2048)
+ carrier = (carrier / torch.norm(carrier, dim=1, keepdim=True)).to(device)  # unit-norm carrier on the model's device
+
+ default_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+
+ def encode(image, epochs=10, psnr=44, lambda_w=1, lambda_i=1):
+     img_orig = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
+     img = img_orig.clone().to(device, non_blocking=True)
+     img.requires_grad = True
+     optimizer = torch.optim.Adam([img], lr=1e-2)
+
+     for iteration in range(epochs):
+         print(f'iteration: {iteration}')
+         x = image_utils.ssim_attenuation(img, img_orig)
+         x = image_utils.psnr_clip(x, img_orig, psnr)
+
+         ft = model(x)  # BxCxWxH -> BxD
+
+         dot_product = ft @ carrier.T  # BxD @ Dx1 -> Bx1
+         norm = torch.norm(ft, dim=-1, keepdim=True)  # Bx1
+         cosines = torch.abs(dot_product/norm)
+         log10_pvalue = np.log10(torch_utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
+         loss_R = -(rho * dot_product**2 - norm**2)  # negative when the feature is inside the hypercone
+
+         loss_l2_img = torch.norm(x - img_orig)**2  # distortion w.r.t. the original image
+         loss = lambda_w*loss_R + lambda_i*loss_l2_img
+
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+
+         logs = {
+             "keyword": "img_optim",
+             "iteration": iteration,
+             "loss": loss.item(),
+             "loss_R": loss_R.item(),
+             "loss_l2_img": loss_l2_img.item(),
+             "log10_pvalue": log10_pvalue.item(),
+         }
+         print("__log__:%s" % json.dumps(logs))
+
+     img = image_utils.ssim_attenuation(img, img_orig)
+     img = image_utils.psnr_clip(img, img_orig, psnr)
+     img = image_utils.round_pixel(img)
+     img = img.squeeze(0).detach().cpu()
+     img = transforms.ToPILImage()(image_utils.unnormalize_img(img).squeeze(0))
+
+     return img
+
+ def decode(image):
+     img = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
+     ft = model(img)  # BxCxWxH -> BxD
+
+     dot_product = ft @ carrier.T  # BxD @ Dx1 -> Bx1
+     norm = torch.norm(ft, dim=-1, keepdim=True)  # Bx1
+     cosines = torch.abs(dot_product/norm)
+     log10_pvalue = np.log10(torch_utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
+     loss_R = -(rho * dot_product**2 - norm**2)  # negative when inside the hypercone
+
+     text_marked = "marked" if loss_R < 0 else "unmarked"
+     return f'Image is {text_marked}, with p-value={10**log10_pvalue}'
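
A minimal usage sketch for this module (illustrative only, not part of the commit; the image paths are placeholders, and importing SSL_watermark assumes dino_r50.pth and out2048.pth sit in the working directory, since the model is built at import time):

    from PIL import Image
    from SSL_watermark import encode, decode

    original = Image.open("sample.png").convert("RGB")  # hypothetical input image

    # Embedding: gradient descent pushes the image's feature into the
    # hypercone around `carrier`, under the SSIM and PSNR constraints.
    marked = encode(original, epochs=10, psnr=44)
    marked.save("sample_marked.png")

    # Detection: loss_R < 0 means the feature lies inside the hypercone.
    print(decode(marked))    # e.g. "Image is marked, with p-value=..."
    print(decode(original))  # most likely "Image is unmarked, ..."
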
app.py CHANGED
@@ -1,6 +1,7 @@
  import gradio as gr
  from steganography import Steganography
  from utils import draw_multiple_line_text, generate_qr_code
+ from SSL_watermark import encode, decode
 
 
  TITLE = """<h2 align="center"> ✍️ Invisible Watermark </h2>"""
@@ -8,20 +9,27 @@ TITLE = """<h2 align="center"> ✍️ Invisible Watermark </h2>"""
 
  def apply_watermark(radio_button, input_image, watermark_image, watermark_text, watermark_url):
      input_image = input_image.convert('RGB')
-
+     print(f'radio_button: {radio_button}')
      if radio_button == "Image":
          watermark_image = watermark_image.resize((input_image.width, input_image.height)).convert('L').convert('RGB')
          return Steganography().merge(input_image, watermark_image, digit=7)
      elif radio_button == "Text":
          watermark_image = draw_multiple_line_text(input_image.size, watermark_text)
          return Steganography().merge(input_image, watermark_image, digit=7)
-     else:
+     elif radio_button == "QRCode":
          size = min(input_image.width, input_image.height)
          watermark_image = generate_qr_code(watermark_url).resize((size, size)).convert('RGB')
          return Steganography().merge(input_image, watermark_image, digit=7)
+     else:
+         print('start encoding ssl watermark...')
+         return encode(input_image, epochs=5)
 
- def extract_watermark(input_image_to_extract):
-     return Steganography().unmerge(input_image_to_extract.convert('RGB'), digit=7).convert('RGBA')
+ def extract_watermark(extract_radio_button, input_image_to_extract):
+     if extract_radio_button == 'Steganography':
+         return Steganography().unmerge(input_image_to_extract.convert('RGB'), digit=7).convert('RGBA')
+     else:
+         decoded_info = decode(image=input_image_to_extract)
+         return draw_multiple_line_text(input_image_size=input_image_to_extract.size, text=decoded_info)
 
 
  with gr.Blocks() as demo:
@@ -34,7 +42,7 @@ with gr.Blocks() as demo:
      with gr.Blocks():
          gr.Markdown("### Which type of watermark you want to apply?")
          radio_button = gr.Radio(
-             choices=["QRCode", "Text", "Image"],
+             choices=["QRCode", "Text", "Image", "SSL Watermark"],
              label="Watermark type",
              value="QRCode",
              # info="Which type of watermark you want to apply?"
@@ -82,6 +90,11 @@ with gr.Blocks() as demo:
          with gr.Column():
              gr.Markdown("### Image to extract watermark")
              input_image_to_extract = gr.Image(type='pil')
+             extract_radio_button = gr.Radio(
+                 choices=["Steganography", "SSL Watermark"],
+                 label="Extract methods",
+                 value="Steganography"
+             )
          with gr.Column():
              gr.Markdown("### Extracted watermark")
              extracted_watermark = gr.Image(type='pil')
@@ -97,6 +110,10 @@ with gr.Blocks() as demo:
          inputs=[radio_button, input_image, watermark_image, watermark_text, watermark_url],
          outputs=[output_image]
      )
-     extract_button.click(fn=extract_watermark, inputs=[input_image_to_extract], outputs=[extracted_watermark])
+     extract_button.click(
+         fn=extract_watermark,
+         inputs=[extract_radio_button, input_image_to_extract],
+         outputs=[extracted_watermark]
+     )
 
  demo.launch()
dino_r50.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ab26d85d00cb1be8e757cf8820cf0fd8aa729ea7e21b1cf6c44875952ba8eb0f
+ size 788803344
image_utils.py ADDED
@@ -0,0 +1,80 @@
+ import numpy as np
+
+ import torch
+ from torchvision import transforms
+
+ import torch.nn.functional as F
+
+ from torch.autograd.variable import Variable
+
+ NORMALIZE_IMAGENET = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ image_mean = torch.Tensor(NORMALIZE_IMAGENET.mean).view(-1, 1, 1).to(device)
+ image_std = torch.Tensor(NORMALIZE_IMAGENET.std).view(-1, 1, 1).to(device)
+
+ def normalize_img(x):
+     return (x.to(device) - image_mean) / image_std
+
+ def unnormalize_img(x):
+     return (x.to(device) * image_std) + image_mean
+
+ def round_pixel(x):
+     x_pixel = 255 * unnormalize_img(x)
+     y = torch.round(x_pixel).clamp(0, 255)
+     y = normalize_img(y/255.0)
+     return y
+
+ def project_linf(x, y, radius):
+     """ Clamp x-y so that Linf(x,y) <= radius """
+     delta = x - y
+     delta = 255 * (delta * image_std)
+     delta = torch.clamp(delta, -radius, radius)
+     delta = (delta / 255.0) / image_std
+     return y + delta
+
+ def psnr_clip(x, y, target_psnr):
+     """ Scale x-y so that PSNR(x,y) >= target_psnr """
+     delta = x - y
+     delta = 255 * (delta * image_std)
+     psnr = 20*np.log10(255) - 10*torch.log10(torch.mean(delta**2))
+     if psnr < target_psnr:
+         delta = (torch.sqrt(10**((psnr-target_psnr)/10))) * delta
+         psnr = 20*np.log10(255) - 10*torch.log10(torch.mean(delta**2))
+     delta = (delta / 255.0) / image_std
+     return y + delta
+
+ def ssim_heatmap(img1, img2, window_size):
+     """ Compute the SSIM heatmap between 2 images """
+     _1D_window = torch.Tensor(
+         [np.exp(-(x - window_size//2)**2/float(2*1.5**2)) for x in range(window_size)]
+     ).to(device, non_blocking=True)
+     _1D_window = (_1D_window/_1D_window.sum()).unsqueeze(1)
+     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+     window = Variable(_2D_window.expand(3, 1, window_size, window_size).contiguous())
+
+     mu1 = F.conv2d(img1, window, padding=window_size//2, groups=3)
+     mu2 = F.conv2d(img2, window, padding=window_size//2, groups=3)
+
+     mu1_sq = mu1.pow(2)
+     mu2_sq = mu2.pow(2)
+     mu1_mu2 = mu1*mu2
+
+     sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=3) - mu1_sq
+     sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=3) - mu2_sq
+     sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=3) - mu1_mu2
+
+     C1 = 0.01**2
+     C2 = 0.03**2
+
+     ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
+     return ssim_map
+
+ def ssim_attenuation(x, y):
+     """ Attenuate x-y using the SSIM heatmap """
+     delta = x - y
+     ssim_map = ssim_heatmap(x, y, window_size=17)  # 1xCxHxW
+     ssim_map = torch.sum(ssim_map, dim=1, keepdim=True)
+     ssim_map = torch.clamp_min(ssim_map, 0)
+     delta = delta*ssim_map
+     return y + delta
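
A quick self-check of psnr_clip (an illustrative sketch, not part of the commit; it assumes ImageNet-normalized tensors, as used throughout this module):

    import numpy as np
    import torch
    import image_utils

    y = image_utils.normalize_img(torch.rand(1, 3, 224, 224))  # "original" image
    x = y + 0.5 * torch.randn_like(y)                          # heavily perturbed copy

    x_clipped = image_utils.psnr_clip(x, y, target_psnr=44)

    # Recompute the PSNR in pixel space to confirm the target is now met.
    delta = 255 * ((x_clipped - y) * image_utils.image_std)
    psnr = 20 * np.log10(255) - 10 * torch.log10(torch.mean(delta ** 2))
    print(f"PSNR after clipping: {psnr.item():.2f} dB")  # ~44 dB
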
out2048.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4b256188454d8f7cf440de048df398e2a3209136a52cd7cdac834f5792f526a3
+ size 16786561
requirements.txt CHANGED
@@ -1,4 +1,7 @@
+ torch==1.10.1
+ torchvision==0.11.2
  Pillow
  click
  gradio
  qrcode
+ scipy
torch_utils.py ADDED
@@ -0,0 +1,84 @@
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+
+ from scipy.optimize import root_scalar
+ from scipy.special import betainc
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def build_backbone(path, name='resnet50'):
+     """ Builds a backbone (ResNet-50 by default) and loads its weights from a checkpoint. """
+     model = getattr(models, name)(pretrained=False)
+     model.head = nn.Identity()
+     model.fc = nn.Identity()
+     checkpoint = torch.load(path, map_location=device)
+     state_dict = checkpoint
+     for ckpt_key in ['state_dict', 'model_state_dict', 'teacher']:
+         if ckpt_key in checkpoint:
+             state_dict = checkpoint[ckpt_key]
+     state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+     state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+     model.load_state_dict(state_dict, strict=False)
+     return model
+
+ def get_linear_layer(weight, bias):
+     """ Creates a layer that performs feature whitening or centering """
+     dim_out, dim_in = weight.shape
+     layer = nn.Linear(dim_in, dim_out)
+     layer.weight = nn.Parameter(weight)
+     layer.bias = nn.Parameter(bias)
+     return layer
+
+ def load_normalization_layer(path):
+     """
+     Loads the normalization layer from a checkpoint and returns the layer.
+     """
+     checkpoint = torch.load(path, map_location=device)
+     if 'whitening' in path or 'out' in path:
+         D = checkpoint['weight'].shape[1]
+         weight = torch.nn.Parameter(D*checkpoint['weight'])
+         bias = torch.nn.Parameter(D*checkpoint['bias'])
+     else:
+         weight = checkpoint['weight']
+         bias = checkpoint['bias']
+     return get_linear_layer(weight, bias).to(device, non_blocking=True)
+
+ class NormLayerWrapper(nn.Module):
+     """
+     Wraps a backbone model and a normalization layer
+     """
+     def __init__(self, backbone, head):
+         super(NormLayerWrapper, self).__init__()
+         backbone.eval(), head.eval()
+         self.backbone = backbone
+         self.head = head
+
+     def forward(self, x):
+         output = self.backbone(x)
+         return self.head(output)
+
+ def cosine_pvalue(c, d, k=1):
+     """
+     Returns the probability that the absolute value of the cosine
+     between random unit vectors is higher than c
+     Args:
+         c: cosine value
+         d: dimension of the features
+         k: number of dimensions of the projection
+     """
+     assert k > 0
+     a = (d - k) / 2.0
+     b = k / 2.0
+     if c < 0:
+         return 1.0
+     return betainc(a, b, 1 - c ** 2)
+
+ def pvalue_angle(dim, k=1, angle=None, proba=None):
+     """ Solves for the angle whose cosine p-value equals proba in dimension dim. """
+     def f(a):
+         return cosine_pvalue(np.cos(a), dim, k) - proba
+     a = root_scalar(f, x0=0.49*np.pi, bracket=[0, np.pi/2])
+     return a.root
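
The constant hard-coded in SSL_watermark.py can be reproduced with pvalue_angle; a small sanity check (illustrative only):

    import numpy as np
    import torch_utils

    # Hypercone half-angle giving FPR=1e-6 at feature dimension D=2048.
    angle = torch_utils.pvalue_angle(dim=2048, k=1, proba=1e-6)
    print(angle)                   # ~1.462771101178447
    print(1 + np.tan(angle) ** 2)  # the rho used in the hypercone loss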