import numpy as np
import json

import torch
from torchvision import transforms

import util_functions.torch_utils as torch_utils 
import util_functions.image_utils as image_utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed seeds: encode and decode must draw the same random carrier.
torch.manual_seed(0)
np.random.seed(0)

print('Building backbone and normalization layer...')
backbone = torch_utils.build_backbone(path='models/dino_r50.pth') # DINO ResNet-50 feature extractor
normlayer = torch_utils.load_normalization_layer(path='models/out2048.pth') # normalization layer, D=2048
model = torch_utils.NormLayerWrapper(backbone, normlayer)
model.to(device).eval() # the extractor is frozen; only the image pixels are optimized

print('Building the hypercone...')
FPR = 1e-6
angle = 1.462771101178447 # half-angle giving FPR=1e-6 for D=2048
rho = 1 + np.tan(angle)**2
carrier = torch.randn(1, 2048)
carrier /= torch.norm(carrier, dim=1, keepdim=True)
carrier = carrier.to(device) # keep the carrier on the same device as the features
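
# Detection tests whether a feature ft lies inside the hypercone of half-angle
# `angle` around the carrier: since rho = 1 + tan(angle)^2 = 1/cos(angle)^2 and
# the carrier is unit-norm,
#   rho * <ft, carrier>^2 > ||ft||^2  <=>  |cos(ft, carrier)| > cos(angle).
# Quick sanity check of that equivalence on a random feature (illustrative
# sketch only, not part of the original pipeline):
_ft = torch.randn(1, 2048, device=device)
_dot = _ft @ carrier.T
_norm = torch.norm(_ft, dim=-1, keepdim=True)
assert ((rho * _dot**2 > _norm**2) == (torch.abs(_dot / _norm) > np.cos(angle))).item()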

default_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # ImageNet statistics
])

def encode(image, epochs=10, psnr=44, lambda_w=1, lambda_i=1):
    """Embed the watermark by gradient descent on the image pixels: push the
    feature into the carrier hypercone (loss_R) while keeping the image close
    to the original (loss_l2_img), under a target PSNR budget."""
    img_orig = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    img = img_orig.clone().to(device, non_blocking=True) 
    img.requires_grad = True
    optimizer = torch.optim.Adam([img], lr=1e-2)

    for iteration in range(epochs):
        print(f'iteration: {iteration}')
        x = image_utils.ssim_attenuation(img, img_orig)
        x = image_utils.psnr_clip(x, img_orig, psnr)

        ft = model(x) # BxCxHxW -> BxD

        dot_product = ft @ carrier.T # BxD @ Dx1 -> Bx1
        norm = torch.norm(ft, dim=-1, keepdim=True) # Bx1
        cosines = torch.abs(dot_product / norm)
        log10_pvalue = np.log10(torch_utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
        loss_R = -(rho * dot_product**2 - norm**2) # Bx1, negative inside the hypercone

        loss_l2_img = torch.norm(x - img_orig)**2 # scalar L2 distortion w.r.t. the original
        loss = lambda_w*loss_R + lambda_i*loss_l2_img
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logs = {
            "keyword": "img_optim",
            "iteration": iteration,
            "loss": loss.item(),
            "loss_R": loss_R.item(),
            "loss_l2_img": loss_l2_img.item(),
            "log10_pvalue": log10_pvalue.item(),
        }
        print("__log__:%s" % json.dumps(logs))

    img = image_utils.ssim_attenuation(img, img_orig)
    img = image_utils.psnr_clip(img, img_orig, psnr)
    img = image_utils.round_pixel(img)
    img = image_utils.unnormalize_img(img) # back to [0, 1] pixel range
    img = img.squeeze(0).detach().cpu()
    img = transforms.ToPILImage()(img)

    return img

def decode(image):
    """Detect the watermark: the image is marked iff its feature lies inside
    the carrier hypercone, i.e. loss_R < 0."""
    img = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    with torch.no_grad():
        ft = model(img) # BxCxHxW -> BxD

    dot_product = ft @ carrier.T # BxD @ Dx1 -> Bx1
    norm = torch.norm(ft, dim=-1, keepdim=True) # Bx1
    cosines = torch.abs(dot_product / norm)
    log10_pvalue = np.log10(torch_utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
    loss_R = -(rho * dot_product**2 - norm**2) # Bx1, negative inside the hypercone

    text_marked = "marked" if loss_R < 0 else "unmarked"
    return f'Image is {text_marked} (log10 p-value: {log10_pvalue:.2f})'
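
# Minimal usage sketch (not part of the original script; 'example.png' is a
# hypothetical path, any RGB image opened with PIL works):
if __name__ == '__main__':
    from PIL import Image

    image = Image.open('example.png').convert('RGB')
    watermarked = encode(image, epochs=10, psnr=44)
    watermarked.save('example_marked.png')
    print(decode(watermarked)) # expected: marked
    print(decode(image))       # expected: unmarked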