Spaces: Runtime error

baixintech_zhangyiming_prod committed on
Commit · 7dd7207
1 Parent(s): 90e2119
init
Browse files
- .gitignore +6 -0
- README.md +1 -1
- app.py +24 -0
- images/clean/3.png +0 -0
- images/watermark/1.png +0 -0
- images/watermark/2.png +0 -0
- requirements.txt +9 -0
- wmdetection/__init__.py +0 -0
- wmdetection/dataset/__init__.py +0 -0
- wmdetection/dataset/synthetic_wm.py +211 -0
- wmdetection/models/__init__.py +86 -0
- wmdetection/models/convnext.py +200 -0
- wmdetection/pipelines/__init__.py +0 -0
- wmdetection/pipelines/metrics.py +9 -0
- wmdetection/pipelines/predictor.py +73 -0
- wmdetection/utils/__init__.py +2 -0
- wmdetection/utils/files.py +26 -0
- wmdetection/utils/fp16module.py +64 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
+__pycache__/
+.ipynb_checkpoints/
+dataset/*.csv
+dataset/watermarks-validation/
+weights/
+model_files/
README.md
CHANGED
@@ -10,4 +10,4 @@ pinned: false
 license: apache-2.0
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,24 @@
+import gradio as gr
+from wmdetection.models import get_watermarks_detection_model
+from wmdetection.pipelines.predictor import WatermarksPredictor
+import os, glob
+
+
+model, transforms = get_watermarks_detection_model(
+    'convnext-tiny',
+    fp16=False,
+    cache_dir='model_files'
+)
+predictor = WatermarksPredictor(model, transforms, 'cuda:0')
+
+
+def predict(image):
+    result = predictor.predict_image(image)
+    return 'watermarked' if result else 'clean'  # class index 1 means "watermarked"
+
+
+examples = glob.glob(os.path.join('images', 'clean', '*'))
+examples.extend(glob.glob(os.path.join('images', 'watermark', '*')))
+iface = gr.Interface(fn=predict, inputs=[gr.inputs.Image(type="pil")],
+                     examples=examples, outputs="text")
+iface.launch()
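Note: the demo above pins the predictor to 'cuda:0' while get_watermarks_detection_model loads the weights with its default device='cpu', so on a CPU-only Space the .to('cuda:0') call inside predict_image would raise, and even with a GPU the input and the model end up on different devices. A minimal sketch of a device-consistent variant, using the same wmdetection API:

import torch
from wmdetection.models import get_watermarks_detection_model
from wmdetection.pipelines.predictor import WatermarksPredictor

# Pick a device that actually exists and pass it to both the model and the predictor.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model, transforms = get_watermarks_detection_model(
    'convnext-tiny', fp16=False, device=device, cache_dir='model_files'
)
predictor = WatermarksPredictor(model, transforms, device)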
images/clean/3.png
ADDED
images/watermark/1.png
ADDED
images/watermark/2.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,9 @@
+torch
+torchvision
+pillow
+numpy
+matplotlib
+tqdm
+huggingface-hub
+opencv-python
+timm>=0.6.12
wmdetection/__init__.py
ADDED
File without changes
wmdetection/dataset/__init__.py
ADDED
File without changes
wmdetection/dataset/synthetic_wm.py
ADDED
@@ -0,0 +1,211 @@
+import os
+from PIL import Image
+import numpy as np
+import pandas as pd
+import cv2
+import string
+import random
+
+CV2_FONTS = [
+    #cv2.FONT_HERSHEY_COMPLEX,
+    cv2.FONT_HERSHEY_COMPLEX_SMALL,
+    cv2.FONT_HERSHEY_DUPLEX,
+    cv2.FONT_HERSHEY_PLAIN,
+    cv2.FONT_HERSHEY_SIMPLEX,
+    cv2.FONT_HERSHEY_TRIPLEX,
+    cv2.FONT_ITALIC,
+    cv2.QT_FONT_BLACK,
+    cv2.QT_FONT_NORMAL
+]
+
+# random float between x and y
+def random_float(x, y):
+    return random.random()*(y-x)+x
+
+# computes the text size in pixels for cv2.putText
+def get_text_size(text, font, font_scale, thickness):
+    (w, h), baseline = cv2.getTextSize(text, font, font_scale, thickness)
+    return w, h+baseline
+
+# computes the font_scale needed for a given text height
+def get_font_scale(needed_height, text, font, thickness):
+    w, h = get_text_size(text, font, 1, thickness)
+    return needed_height/h
+
+# draws text onto an image
+def place_text(image, text, color=(255,255,255), alpha=1, position=(0, 0), angle=0,
+               font=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1.0, thickness=3):
+    image = np.array(image)
+    overlay = np.zeros_like(image)
+    output = image.copy()
+
+    cv2.putText(overlay, text, position, font, font_scale, color, thickness)
+
+    if angle != 0:
+        text_w, text_h = get_text_size(text, font, font_scale, thickness)
+        rotate_M = cv2.getRotationMatrix2D((position[0]+text_w//2, position[1]-text_h//2), angle, 1)
+        overlay = cv2.warpAffine(overlay, rotate_M, (overlay.shape[1], overlay.shape[0]))
+
+    overlay[overlay==0] = image[overlay==0]
+    cv2.addWeighted(overlay, alpha, output, 1-alpha, 0, output)
+
+    return Image.fromarray(output)
+
+def get_random_font_params(text, text_height, fonts, font_thickness_range):
+    font = random.choice(fonts)
+    font_thickness_range_scaled = [int(font_thickness_range[0]*(text_height/35)),
+                                   int(font_thickness_range[1]*(text_height/85))]
+    try:
+        font_thickness = min(random.randint(*font_thickness_range_scaled), 2)
+    except ValueError:
+        font_thickness = 2
+    font_scale = get_font_scale(text_height, text, font, font_thickness)
+    return font, font_scale, font_thickness
+
+# places a watermark at the center of the image with random parameters
+def place_random_centered_watermark(
+    pil_image,
+    text,
+    center_point_range_shift=(-0.025, 0.025),
+    random_angle=(0,0),
+    text_height_in_percent_range=(0.15, 0.18),
+    text_alpha_range=(0.23, 0.5),
+    fonts=CV2_FONTS,
+    font_thickness_range=(2, 7),
+    colors=[(255,255,255)]
+):
+    w, h = pil_image.size
+
+    position_shift_x = random_float(*center_point_range_shift)
+    offset_x = int(w*position_shift_x)
+    position_shift_y = random_float(*center_point_range_shift)
+    offset_y = int(w*position_shift_y)
+
+    text_height = int(h*random_float(*text_height_in_percent_range))
+
+    font, font_scale, font_thickness = get_random_font_params(text, text_height, fonts, font_thickness_range)
+
+    text_width, _ = get_text_size(text, font, font_scale, font_thickness)
+
+    position_x = int((w/2)-text_width/2+offset_x)
+    position_y = int((h/2)+text_height/2+offset_y)
+
+    return place_text(
+        pil_image,
+        text,
+        color=random.choice(colors),
+        alpha=random_float(*text_alpha_range),
+        position=(position_x, position_y),
+        angle=random.randint(*random_angle),
+        thickness=font_thickness,
+        font=font,
+        font_scale=font_scale
+    )
+
+def place_random_watermark(
+    pil_image,
+    text,
+    random_angle=(0,0),
+    text_height_in_percent_range=(0.10, 0.18),
+    text_alpha_range=(0.18, 0.4),
+    fonts=CV2_FONTS,
+    font_thickness_range=(2, 6),
+    colors=[(255,255,255)]
+):
+    w, h = pil_image.size
+
+    text_height = int(h*random_float(*text_height_in_percent_range))
+
+    font, font_scale, font_thickness = get_random_font_params(text, text_height, fonts, font_thickness_range)
+
+    text_width, _ = get_text_size(text, font, font_scale, font_thickness)
+
+    position_x = random.randint(0, max(w-text_width, 10))
+    position_y = random.randint(text_height, h)
+
+    return place_text(
+        pil_image,
+        text,
+        color=random.choice(colors),
+        alpha=random_float(*text_alpha_range),
+        position=(position_x, position_y),
+        angle=random.randint(*random_angle),
+        thickness=font_thickness,
+        font=font,
+        font_scale=font_scale
+    )
+
+def center_crop(image, w, h):
+    center = image.shape
+    x = center[1]/2 - w/2
+    y = center[0]/2 - h/2
+    return image[int(y):int(y+h), int(x):int(x+w)]
+
+# draws the text in a checkerboard pattern over the image
+def place_text_checkerboard(image, text, color=(255,255,255), alpha=1, step_x=0.1, step_y=0.1, angle=0,
+                            font=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1.0, thickness=3):
+    image_size = image.size
+
+    image = np.array(image.convert('RGB'))
+    if angle != 0:
+        border_scale = 0.4
+        overlay_size = [int(i*(1+border_scale)) for i in list(image_size)]
+    else:
+        overlay_size = image_size
+
+    w, h = overlay_size
+    overlay = np.zeros((overlay_size[1], overlay_size[0], 3)) # change dimensions
+    output = image.copy()
+
+    text_w, text_h = get_text_size(text, font, font_scale, thickness)
+
+    c = 0
+    for rel_pos_x in np.arange(0, 1, step_x):
+        c += 1
+        for rel_pos_y in np.arange(text_h/h+(c%2)*step_y/2, 1, step_y):
+            position = (int(w*rel_pos_x), int(h*rel_pos_y))
+            cv2.putText(overlay, text, position, font, font_scale, color, thickness)
+
+    if angle != 0:
+        rotate_M = cv2.getRotationMatrix2D((w//2, h//2), angle, 1)
+        overlay = cv2.warpAffine(overlay, rotate_M, (overlay.shape[1], overlay.shape[0]))
+
+    overlay = center_crop(overlay, image_size[0], image_size[1])
+    overlay[overlay==0] = image[overlay==0]
+    overlay = overlay.astype(np.uint8)
+    cv2.addWeighted(overlay, alpha, output, 1-alpha, 0, output)
+
+    return Image.fromarray(output)
+
+def place_random_diagonal_watermark(
+    pil_image,
+    text,
+    random_step_x=(0.25, 0.4),
+    random_step_y=(0.25, 0.4),
+    random_angle=(-60,60),
+    text_height_in_percent_range=(0.10, 0.18),
+    text_alpha_range=(0.18, 0.4),
+    fonts=CV2_FONTS,
+    font_thickness_range=(2, 6),
+    colors=[(255,255,255)]
+):
+    w, h = pil_image.size
+
+    text_height = int(h*random_float(*text_height_in_percent_range))
+
+    font, font_scale, font_thickness = get_random_font_params(text, text_height, fonts, font_thickness_range)
+
+    text_width, _ = get_text_size(text, font, font_scale, font_thickness)
+
+    return place_text_checkerboard(
+        pil_image,
+        text,
+        color=random.choice(colors),
+        alpha=random_float(*text_alpha_range),
+        step_x=random_float(*random_step_x),
+        step_y=random_float(*random_step_y),
+        angle=random.randint(*random_angle),
+        thickness=font_thickness,
+        font=font,
+        font_scale=font_scale
+    )
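For context, a short usage sketch for the synthetic watermark helpers above; the input path comes from this commit's images/ folder and the watermark text is an arbitrary example:

from PIL import Image
from wmdetection.dataset.synthetic_wm import (
    place_random_centered_watermark,
    place_random_diagonal_watermark,
)

img = Image.open('images/clean/3.png').convert('RGB')

# Single semi-transparent watermark near the image center, random font/thickness.
centered = place_random_centered_watermark(img, 'example.com')

# Tiled checkerboard pattern, rotated by a random angle.
diagonal = place_random_diagonal_watermark(img, 'example.com', random_angle=(-45, 45))

centered.save('centered_wm.png')
diagonal.save('diagonal_wm.png')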
wmdetection/models/__init__.py
ADDED
@@ -0,0 +1,86 @@
+import os
+import torch
+import torch.nn as nn
+from torchvision import models, transforms
+from huggingface_hub import hf_hub_url, hf_hub_download
+
+from .convnext import ConvNeXt
+from wmdetection.utils import FP16Module
+
+
+def get_convnext_model(name):
+    if name == 'convnext-tiny':
+        model_ft = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
+        model_ft.head = nn.Sequential(
+            nn.Linear(in_features=768, out_features=512),
+            nn.GELU(),
+            nn.Linear(in_features=512, out_features=256),
+            nn.GELU(),
+            nn.Linear(in_features=256, out_features=2),
+        )
+
+    detector_transforms = transforms.Compose([
+        transforms.Resize((256, 256)),
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+    return model_ft, detector_transforms
+
+
+def get_resnext_model(name):
+    if name == 'resnext50_32x4d-small':
+        model_ft = models.resnext50_32x4d(pretrained=False)
+    elif name == 'resnext101_32x8d-large':
+        model_ft = models.resnext101_32x8d(pretrained=False)
+
+    num_ftrs = model_ft.fc.in_features
+    model_ft.fc = nn.Linear(num_ftrs, 2)
+
+    detector_transforms = transforms.Compose([
+        transforms.Resize((320, 320)),
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+
+    return model_ft, detector_transforms
+
+
+def get_watermarks_detection_model(name, device='cpu', fp16=True, pretrained=True, cache_dir='/tmp/watermark-detection'):
+    assert name in MODELS, f"Unknown model name: {name}"
+    assert not (fp16 and name.startswith('convnext')), "Can`t use fp16 mode with convnext models"
+    config = MODELS[name]
+
+    model_ft, detector_transforms = config['constructor'](name)
+
+    if pretrained:
+        hf_hub_download(repo_id=config['repo_id'], filename=config['filename'],
+                        cache_dir=cache_dir, force_filename=config['filename'])
+        weights = torch.load(os.path.join(cache_dir, config['filename']), device)
+        model_ft.load_state_dict(weights)
+
+    if fp16:
+        model_ft = FP16Module(model_ft)
+
+    model_ft.eval()
+    model_ft = model_ft.to(device)
+
+    return model_ft, detector_transforms
+
+
+MODELS = {
+    'convnext-tiny': dict(
+        constructor=get_convnext_model,
+        repo_id='boomb0om/watermark-detectors',
+        filename='convnext-tiny_watermarks_detector.pth',
+    ),
+    'resnext101_32x8d-large': dict(
+        constructor=get_resnext_model,
+        repo_id='boomb0om/watermark-detectors',
+        filename='watermark_classifier-resnext101_32x8d-input_size320-4epochs_c097_w082.pth',
+    ),
+    'resnext50_32x4d-small': dict(
+        constructor=get_resnext_model,
+        repo_id='boomb0om/watermark-detectors',
+        filename='watermark_classifier-resnext50_32x4d-input_size320-4epochs_c082_w078.pth',
+    )
+}
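A hedged sketch of how the registry above can be used outside the Gradio app, assuming the checkpoint download from boomb0om/watermark-detectors succeeds and that class index 1 means "watermarked" (as app.py assumes):

import torch
from PIL import Image
from wmdetection.models import get_watermarks_detection_model

# Load the smaller ResNeXt variant on CPU; fp16 is disabled because half precision
# is only practical on GPU (and is rejected outright for the convnext models).
model, transforms = get_watermarks_detection_model(
    'resnext50_32x4d-small', device='cpu', fp16=False, cache_dir='model_files'
)

img = Image.open('images/watermark/1.png').convert('RGB')
with torch.no_grad():
    logits = model(transforms(img).unsqueeze(0))
print('watermarked' if logits.argmax(1).item() == 1 else 'clean')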
wmdetection/models/convnext.py
ADDED
@@ -0,0 +1,200 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import trunc_normal_, DropPath
+from timm.models.registry import register_model
+
+
+class Block(nn.Module):
+    r""" ConvNeXt Block. There are two equivalent implementations:
+    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+    We use (2) as we find it slightly faster in PyTorch
+
+    Args:
+        dim (int): Number of input channels.
+        drop_path (float): Stochastic depth rate. Default: 0.0
+        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+    """
+    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
+        super().__init__()
+        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
+        self.norm = LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(4 * dim, dim)
+        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
+                                  requires_grad=True) if layer_scale_init_value > 0 else None
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x):
+        input = x
+        x = self.dwconv(x)
+        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+        if self.gamma is not None:
+            x = self.gamma * x
+        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+        x = input + self.drop_path(x)
+        return x
+
+class ConvNeXt(nn.Module):
+    r""" ConvNeXt
+        A PyTorch impl of : `A ConvNet for the 2020s` -
+        https://arxiv.org/pdf/2201.03545.pdf
+
+    Args:
+        in_chans (int): Number of input image channels. Default: 3
+        num_classes (int): Number of classes for classification head. Default: 1000
+        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
+        drop_path_rate (float): Stochastic depth rate. Default: 0.
+        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+    """
+    def __init__(self, in_chans=3, num_classes=1000,
+                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
+                 layer_scale_init_value=1e-6, head_init_scale=1.,
+                 ):
+        super().__init__()
+
+        self.dims = dims
+
+        self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
+        stem = nn.Sequential(
+            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
+        )
+        self.downsample_layers.append(stem)
+        for i in range(3):
+            downsample_layer = nn.Sequential(
+                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
+            )
+            self.downsample_layers.append(downsample_layer)
+
+        self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
+        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+        cur = 0
+        for i in range(4):
+            stage = nn.Sequential(
+                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
+                        layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
+            )
+            self.stages.append(stage)
+            cur += depths[i]
+
+        self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
+        self.head = nn.Linear(dims[-1], num_classes)
+
+        self.apply(self._init_weights)
+        self.head.weight.data.mul_(head_init_scale)
+        self.head.bias.data.mul_(head_init_scale)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv2d, nn.Linear)):
+            trunc_normal_(m.weight, std=.02)
+            nn.init.constant_(m.bias, 0)
+
+    def forward_features(self, x):
+        for i in range(4):
+            x = self.downsample_layers[i](x)
+            x = self.stages[i](x)
+        return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+class LayerNorm(nn.Module):
+    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
+    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+    with shape (batch_size, channels, height, width).
+    """
+    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.data_format = data_format
+        if self.data_format not in ["channels_last", "channels_first"]:
+            raise NotImplementedError
+        self.normalized_shape = (normalized_shape, )
+
+    def forward(self, x):
+        if self.data_format == "channels_last":
+            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        elif self.data_format == "channels_first":
+            u = x.mean(1, keepdim=True)
+            s = (x - u).pow(2).mean(1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.eps)
+            x = self.weight[:, None, None] * x + self.bias[:, None, None]
+            return x
+
+
+model_urls = {
+    "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+    "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+    "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+    "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+    "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
+    "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
+    "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
+    "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
+    "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
+}
+
+def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
+    if pretrained:
+        url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
+        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
+        model.load_state_dict(checkpoint["model"])
+    return model
+
+def convnext_small(pretrained=False, in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
+    if pretrained:
+        url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
+        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+        model.load_state_dict(checkpoint["model"])
+    return model
+
+def convnext_base(pretrained=False, in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
+    if pretrained:
+        url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
+        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+        model.load_state_dict(checkpoint["model"])
+    return model
+
+def convnext_large(pretrained=False, in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
+    if pretrained:
+        url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
+        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+        model.load_state_dict(checkpoint["model"])
+    return model
+
+def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
+    if pretrained:
+        assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
+        url = model_urls['convnext_xlarge_22k']
+        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+        model.load_state_dict(checkpoint["model"])
+    return model
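A quick sanity-check sketch for the backbone as configured in this repo (random weights, no download; the expected output shape is an assumption based on the code above):

import torch
from wmdetection.models.convnext import ConvNeXt

# 'tiny' configuration with a 2-class head, dummy forward pass at the 256x256
# resolution used by the detector transforms.
net = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], num_classes=2).eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 256, 256))
print(out.shape)  # expected: torch.Size([1, 2])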
wmdetection/pipelines/__init__.py
ADDED
File without changes
wmdetection/pipelines/metrics.py
ADDED
@@ -0,0 +1,9 @@
+import pandas as pd
+import numpy as np
+from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
+
+
+def plot_confusion_matrix(x: np.ndarray, y: np.ndarray):
+    cm = confusion_matrix(x, y)
+    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['clean', 'watermark'])
+    return disp.plot()
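Note that scikit-learn is imported here but not listed in requirements.txt, so it has to be installed separately. A small, hypothetical usage sketch (labels invented for illustration; 0 = clean, 1 = watermark, matching display_labels):

import numpy as np
from wmdetection.pipelines.metrics import plot_confusion_matrix

y_true = np.array([0, 0, 1, 1, 1])     # hypothetical ground-truth labels
y_pred = np.array([0, 1, 1, 1, 0])     # hypothetical predictions
plot_confusion_matrix(y_true, y_pred)  # first argument is treated as ground truth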
wmdetection/pipelines/predictor.py
ADDED
@@ -0,0 +1,73 @@
+import os
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+
+import torch
+import torchvision
+from torch.utils.data import Dataset
+from torch.utils.data import BatchSampler, DataLoader
+
+from wmdetection.utils import read_image_rgb
+
+
+class ImageDataset(Dataset):
+
+    def __init__(self, objects, classifier_transforms):
+        self.objects = objects
+        self.classifier_transforms = classifier_transforms
+
+    def __len__(self):
+        return len(self.objects)
+
+    def __getitem__(self, idx):
+        obj = self.objects[idx]
+        assert isinstance(obj, (str, np.ndarray, Image.Image))
+
+        if isinstance(obj, str):
+            pil_img = read_image_rgb(obj)
+        elif isinstance(obj, np.ndarray):
+            pil_img = Image.fromarray(obj)
+        elif isinstance(obj, Image.Image):
+            pil_img = obj
+
+        resnet_img = self.classifier_transforms(pil_img).float()
+
+        return resnet_img
+
+
+class WatermarksPredictor:
+
+    def __init__(self, wm_model, classifier_transforms, device):
+        self.wm_model = wm_model
+        self.wm_model.eval()
+        self.classifier_transforms = classifier_transforms
+
+        self.device = device
+
+    def predict_image(self, pil_image):
+        pil_image = pil_image.convert("RGB")
+        input_img = self.classifier_transforms(pil_image).float().unsqueeze(0)
+        outputs = self.wm_model(input_img.to(self.device))
+        result = torch.max(outputs, 1)[1].cpu().reshape(-1).tolist()[0]
+        return result
+
+    def run(self, files, num_workers=8, bs=8, pbar=True):
+        eval_dataset = ImageDataset(files, self.classifier_transforms)
+        loader = DataLoader(
+            eval_dataset,
+            sampler=torch.utils.data.SequentialSampler(eval_dataset),
+            batch_size=bs,
+            drop_last=False,
+            num_workers=num_workers
+        )
+        if pbar:
+            loader = tqdm(loader)
+
+        result = []
+        for batch in loader:
+            with torch.no_grad():
+                outputs = self.wm_model(batch.to(self.device))
+                result.extend(torch.max(outputs, 1)[1].cpu().reshape(-1).tolist())
+
+        return result
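A usage sketch for batch prediction over a folder, combining the predictor with the file helpers from wmdetection.utils (the CPU device and the small worker/batch sizes are arbitrary choices for the example):

from wmdetection.models import get_watermarks_detection_model
from wmdetection.pipelines.predictor import WatermarksPredictor
from wmdetection.utils import list_images

model, transforms = get_watermarks_detection_model('convnext-tiny', fp16=False,
                                                   cache_dir='model_files')
predictor = WatermarksPredictor(model, transforms, 'cpu')

files = list_images('images')                        # recursive .jpg/.jpeg/.png scan
labels = predictor.run(files, num_workers=0, bs=4)   # list of 0/1 class indices
for path, label in zip(files, labels):
    print(path, 'watermarked' if label else 'clean')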
wmdetection/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .files import get_extenstion, listdir_rec, list_images, read_image_rgb
+from .fp16module import FP16Module
wmdetection/utils/files.py
ADDED
@@ -0,0 +1,26 @@
+import os
+from PIL import Image
+
+IMAGE_EXT = set(['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG'])
+
+def get_extenstion(filepath):
+    return os.path.splitext(filepath)[-1]
+
+def listdir_rec(folder_path):
+    filepaths = []
+    for root, dirname, files in os.walk(folder_path):
+        for file in files:
+            filepaths.append(os.path.join(root, file))
+    return filepaths
+
+def list_images(folder_path):
+    files = listdir_rec(folder_path)
+    return [f for f in files if get_extenstion(f) in IMAGE_EXT]
+
+def read_image_rgb(path):
+    pil_img = Image.open(path)
+    pil_img.load()
+    if pil_img.format == 'PNG' and pil_img.mode != 'RGBA':
+        pil_img = pil_img.convert('RGBA')
+    pil_img = pil_img.convert('RGB')
+    return pil_img
wmdetection/utils/fp16module.py
ADDED
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+import torch
+from torch import nn
+from torch.autograd import Variable
+from torch.nn.parameter import Parameter
+
+FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
+HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
+
+
+def conversion_helper(val, conversion):
+    """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
+    if not isinstance(val, (tuple, list)):
+        return conversion(val)
+    rtn = [conversion_helper(v, conversion) for v in val]
+    if isinstance(val, tuple):
+        rtn = tuple(rtn)
+    return rtn
+
+
+def fp32_to_fp16(val):
+    """Convert fp32 `val` to fp16"""
+    def half_conversion(val):
+        val_typecheck = val
+        if isinstance(val_typecheck, (Parameter, Variable)):
+            val_typecheck = val.data
+        if isinstance(val_typecheck, FLOAT_TYPES):
+            val = val.half()
+        return val
+    return conversion_helper(val, half_conversion)
+
+
+def fp16_to_fp32(val):
+    """Convert fp16 `val` to fp32"""
+    def float_conversion(val):
+        val_typecheck = val
+        if isinstance(val_typecheck, (Parameter, Variable)):
+            val_typecheck = val.data
+        if isinstance(val_typecheck, HALF_TYPES):
+            val = val.float()
+        return val
+    return conversion_helper(val, float_conversion)
+
+
+class FP16Module(nn.Module):
+    def __init__(self, module):
+        super(FP16Module, self).__init__()
+        self.add_module('module', module.half())
+
+    def forward(self, *inputs, **kwargs):
+        return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        return self.module.state_dict(destination, prefix, keep_vars)
+
+    def load_state_dict(self, state_dict, strict=True):
+        self.module.load_state_dict(state_dict, strict=strict)
+
+    def get_param(self, item):
+        return self.module.get_param(item)
+
+    def to(self, device, *args, **kwargs):
+        self.module.to(device)
+        return super().to(device, *args, **kwargs)
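Finally, a usage sketch for FP16Module (assumption: a CUDA device is available, since half precision is the point of the wrapper and is only practical on GPU here):

import torch
from torchvision import models
from wmdetection.utils import FP16Module

# Wrap an fp32 network: parameters are converted with module.half(), inputs are
# cast to fp16 on the way in and outputs back to fp32 on the way out.
net = FP16Module(models.resnet18()).to('cuda:0').eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 224, 224, device='cuda:0'))
print(out.dtype)  # torch.float32, thanks to fp16_to_fp32 in forward()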