# Invisible_Watermark / SSL_watermark.py
# Last commit: 21ed458 "chore: update layout" (AnsenH)
import numpy as np
import json
import torch
from torchvision import transforms
import util_functions.torch_utils as torch_utils
import util_functions.image_utils as image_utils
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Seed both RNGs so the randomly drawn carrier (torch.randn below) is the
# same on every run.
np.random.seed(0)
torch.manual_seed(0)
print('Building backbone and normalization layer...')
backbone = torch_utils.build_backbone(path='models/dino_r50.pth')
normlayer = torch_utils.load_normalization_layer(path='models/out2048.pth')
model = torch_utils.NormLayerWrapper(backbone, normlayer)
# encode()/decode() move their input tensors to `device`; keep the network
# there too (a no-op if torch_utils already placed it).
model = model.to(device)
# The network is only a fixed feature extractor: encode() optimizes the image
# pixels, never the weights.
# - eval(): normalization layers (e.g. BatchNorm in the ResNet-50 backbone)
#   use their stored statistics instead of per-batch ones, and forward passes
#   do not update them.
# - requires_grad_(False): loss.backward() in encode() no longer accumulates
#   gradients into the weights. Gradients still flow to the input image.
model.eval()
model.requires_grad_(False)
print('Building the hypercone...')
# Target false-positive rate. Informational only: `angle` below is a
# precomputed constant and must be recomputed by hand if FPR or D changes.
FPR = 1e-6
angle = 1.462771101178447  # value for FPR=1e-6 and D=2048
# rho = 1 + tan(angle)^2 = 1 / cos(angle)^2. A feature vector ft lies inside
# the hypercone around `carrier` when rho * <ft, carrier>^2 > ||ft||^2,
# i.e. when |cos(ft, carrier)| > cos(angle).
rho = 1 + np.tan(angle)**2
# Random unit-norm direction in the 2048-d feature space (the secret key).
carrier = torch.randn(1, 2048)
carrier /= torch.norm(carrier, dim=1, keepdim=True)
# encode()/decode() run the model on `device`, so the carrier must live there
# too; otherwise `ft @ carrier.T` mixes CPU and CUDA tensors.
carrier = carrier.to(device)
# Preprocessing shared by encode() and decode(): convert the input image to a
# float CxHxW tensor, then normalize each channel with the mean/std below
# (the standard ImageNet statistics used by torchvision's pretrained models).
default_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def encode(image, epochs=10, psnr=44, lambda_w=1, lambda_i=1):
    """Embed the zero-bit watermark into ``image`` by optimizing its pixels.

    Adam updates the normalized pixel tensor so that the feature vector
    produced by the module-level ``model`` moves inside the hypercone around
    ``carrier`` (``loss_R`` < 0). A squared-L2 term keeps the result close to
    the original. Before every forward pass, the candidate goes through
    ``image_utils.ssim_attenuation`` and ``image_utils.psnr_clip``. The same
    two steps plus ``image_utils.round_pixel`` are applied to the final image.
    Prints progress and one ``__log__:`` JSON line per iteration.

    Args:
        image: input accepted by ``default_transform`` (e.g. a PIL image or
            an HxWxC array). Non-RGB PIL images are converted to RGB first.
        epochs: number of optimization steps (one forward/backward each).
        psnr: PSNR value forwarded to ``image_utils.psnr_clip``.
        lambda_w: weight of the watermark (hypercone) loss ``loss_R``.
        lambda_i: weight of the squared-L2 image-distortion loss.

    Returns:
        The watermarked image as a PIL image.
    """
    # RGBA / grayscale / palette PIL images would give a channel count that the
    # 3-channel Normalize in default_transform rejects.
    if hasattr(image, "convert") and getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")
    img_orig = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    img = img_orig.clone()  # already on `device`
    img.requires_grad = True
    optimizer = torch.optim.Adam([img], lr=1e-2)
    # The carrier may have been created on the CPU; match the model's device
    # once, outside the loop.
    carrier_dev = carrier.to(device)
    for iteration in range(epochs):
        print(f'iteration: {iteration}')
        x = image_utils.ssim_attenuation(img, img_orig)
        x = image_utils.psnr_clip(x, img_orig, psnr)
        ft = model(x)  # BxCxWxH -> BxD
        dot_product = ft @ carrier_dev.T  # BxD @ Dx1 -> Bx1
        norm = torch.norm(ft, dim=-1, keepdim=True)  # Bx1
        cosines = torch.abs(dot_product / norm)
        log10_pvalue = np.log10(torch_utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
        # Negative iff |cos(ft, carrier)| > cos(angle), i.e. ft is inside the cone.
        loss_R = -(rho * dot_product**2 - norm**2)  # Bx1
        loss_l2_img = torch.norm(x - img_orig)**2  # scalar
        loss = lambda_w * loss_R + lambda_i * loss_l2_img
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        logs = {
            "keyword": "img_optim",
            "iteration": iteration,
            "loss": loss.item(),
            "loss_R": loss_R.item(),
            "loss_l2_img": loss_l2_img.item(),
            # float() accepts a Python float, a NumPy scalar or a 1-element tensor.
            "log10_pvalue": float(log10_pvalue),
        }
        print("__log__:%s" % json.dumps(logs))
    # The final attenuation/clip/rounding is pure post-processing, so there is
    # no need to record an autograd graph for it.
    with torch.no_grad():
        img = image_utils.ssim_attenuation(img, img_orig)
        img = image_utils.psnr_clip(img, img_orig, psnr)
        img = image_utils.round_pixel(img)
    img = img.squeeze(0).detach().cpu()
    img = transforms.ToPILImage()(image_utils.unnormalize_img(img).squeeze(0))
    return img
def decode(image):
    """Report whether ``image`` carries the zero-bit watermark.

    The image counts as marked when its feature vector ``ft`` (from the
    module-level ``model``) lies inside the hypercone around ``carrier``:
    rho * <ft, carrier>^2 > ||ft||^2, i.e. |cos(ft, carrier)| > cos(angle).

    Args:
        image: input accepted by ``default_transform`` (e.g. a PIL image or
            an HxWxC array). Non-RGB PIL images are converted to RGB first.

    Returns:
        ``'Image is marked'`` or ``'Image is unmarked'``.
    """
    # RGBA / grayscale / palette PIL images would give a channel count that the
    # 3-channel Normalize in default_transform rejects.
    if hasattr(image, "convert") and getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")
    img = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        ft = model(img)  # BxCxWxH -> BxD
        dot_product = ft @ carrier.to(device).T  # BxD @ Dx1 -> Bx1
        norm = torch.norm(ft, dim=-1, keepdim=True)  # Bx1
        # Negative iff ft is inside the hypercone (same criterion as encode()).
        loss_R = -(rho * dot_product**2 - norm**2)  # Bx1
    text_marked = "marked" if loss_R.item() < 0 else "unmarked"
    return f'Image is {text_marked}'