File size: 2,536 Bytes
bd0a3d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from huggingface_hub import PyTorchModelHubMixin
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

from model import ICN

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def mask_processing(x):
    """Map one 8-bit grayscale value to an alpha level for the overlay mask.

    Returns 140 (semi-transparent) for values above 90, 0 (transparent) for
    values below 80, and 255 (opaque) for the band 80..90 in between, which
    after upscaling traces an outline around high-score regions.
    """
    # Comparisons are evaluated in the same order as a >90 / <80 / else chain.
    return 140 if x > 90 else (0 if x < 80 else 255)


def grid_to_heatmap(grid, size=1024):
    """Render a flat square score grid as a coloured heatmap plus alpha mask.

    Args:
        grid: tensor holding ``side * side`` scores (the app's model yields a
            7x7 grid). Float scores are expected in [0, 1]; they are scaled to
            8-bit by ``to_pil_image``. Any device, layout, or grad state is
            accepted.
        size: edge length in pixels of the returned square images.

    Returns:
        ``(heatmap, mask)``: an RGBA ``PIL.Image`` coloured with the "Wistia"
        colormap, and the single-channel mask (after ``mask_processing``)
        meant to be used as the paste alpha.

    Raises:
        ValueError: if ``grid`` is empty or its element count is not a
            perfect square.
    """
    numel = grid.numel()
    side = round(numel ** 0.5)
    if side == 0 or side * side != numel:
        raise ValueError(f"grid must hold a non-empty square number of cells, got {numel}")

    # detach/cpu so autograd- or GPU-backed tensors convert cleanly; reshape
    # (unlike view) also accepts non-contiguous inputs.
    cells = grid.detach().cpu().reshape(side, side)
    mask = to_pil_image(cells)
    mask = mask.resize((size, size), Image.BICUBIC)
    mask = Image.eval(mask, mask_processing)

    colormap = plt.get_cmap("Wistia")
    # The mask is uint8, so the colormap indexes its lookup table directly
    # with each pixel value and returns RGBA floats in [0, 1].
    rgba = colormap(np.asarray(mask))
    heatmap = Image.fromarray((rgba * 255).astype(np.uint8))

    return heatmap, mask


def summary_image(img, fake, prediction, size=1024):
    """Overlay the predicted heatmap on both input images.

    Args:
        img: first input image (``PIL.Image``).
        fake: second input image (``PIL.Image``).
        prediction: tensor of raw grid scores for the pair. It is min-max
            normalised here and is NOT modified in place.
        size: edge length of the square output images (default 1024).

    Returns:
        ``(img1, img2)``: resized copies of ``img`` and ``fake`` with the
        heatmap pasted on top through its alpha mask.
    """
    # Out-of-place arithmetic so the caller's tensor is left untouched.
    prediction = prediction - prediction.min()
    peak = prediction.max()
    # A constant grid (e.g. the all-zero heatmap used when nothing should be
    # highlighted) would otherwise compute 0 / 0 and fill the grid with NaN.
    if peak > 0:
        prediction = prediction / peak

    img1 = img.resize((size, size))
    img2 = fake.resize((size, size))

    # Render the overlay at the same size as the images it is pasted onto.
    heatmap, mask = grid_to_heatmap(prediction, size=size)
    img1.paste(heatmap, (0, 0), mask)
    img2.paste(heatmap, (0, 0), mask)

    return img1, img2


@st.cache_resource
def load_model(path="traced_model.pt"):
    """Load the TorchScript export of the model, cached by Streamlit.

    ``st.cache_resource`` keeps the returned module across reruns, so the file
    is read once per distinct ``path``.

    Args:
        path: location of the traced model file.

    Returns:
        The loaded TorchScript module, in eval mode and on ``device``.
    """
    # map_location remaps stored tensors while loading, so a model traced on
    # a GPU machine can still be opened on a CPU-only host.
    scripted = torch.jit.load(path, map_location=device)
    scripted.eval().to(device)
    return scripted


@st.cache_resource
def _load_hub_model():
    """Build the ``AlexBlck/image-comparator`` model once and reuse it.

    Streamlit re-executes this script on every widget interaction; caching the
    resource avoids re-instantiating the network and re-copying its weights to
    ``device`` on each rerun.
    """
    return ICN.from_pretrained("AlexBlck/image-comparator").eval().to(device)


model = _load_hub_model()

# To use the local TorchScript export instead, call load_model().

st.title("Image Comparator Network")

st.write("## Upload a pair of images")
cols = st.columns(2)
# One uploader per column, created left to right.
uploads = []
for column, label in zip(cols, ("Image 1", "Image 2")):
    with column:
        uploads.append(st.file_uploader(label, type=["jpg", "png"]))
im1, im2 = uploads

# End this run until both files are present...
if not im1 or not im2:
    st.stop()

# ...and until the user explicitly asks for a comparison.
btn = st.button("Run")
if not btn:
    st.stop()

# Decode both uploads and force 3-channel RGB.
im1, im2 = (Image.open(upload).convert("RGB") for upload in (im1, im2))

# Model input pipeline: fixed 224x224 resize, conversion to a [0, 1] tensor,
# then per-channel normalisation. The mean/std are the usual ImageNet
# statistics, presumably matching the model's training preprocessing.
tr = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Stack the two 3-channel images along the channel axis -> (1, 6, 224, 224).
img = torch.vstack((tr(im1), tr(im2))).unsqueeze(0)

# Pure inference: disable autograd so no graph/activations are retained.
with torch.no_grad():
    heatmap, cl = model(img.to(device))
confs = torch.softmax(cl, dim=1)
pred = torch.argmax(confs, dim=1).item()

# Class 0 = no manipulation, 1 = manipulated, anything else = unrelated pair.
# Only the "manipulated" verdict keeps a visible heatmap.
if pred == 0:
    st.success("No Manipulation Detected")
    heatmap *= 0
elif pred == 1:
    st.warning("Manipulation Detected!")
else:
    st.error("Images are not related.")
    heatmap *= 0

# Bring the grid to the CPU for the PIL-based rendering helpers.
img1, img2 = summary_image(im1, im2, heatmap[0].cpu())
cols = st.columns(2)
with cols[0]:
    st.image(img1)
with cols[1]:
    st.image(img2)