Spaces: baixintech_zhangyiming_prod (Runtime error)
Commit 53a3db7 · 1 Parent(s): ad6f6d7 · committed by baixintech_zhangyiming_prod
output with softmax

Files changed:
- app.py (+8 -5)
- wmdetection/pipelines/predictor.py (+6 -0)
app.py CHANGED
@@ -12,13 +12,16 @@ model, transforms = get_watermarks_detection_model(
 predictor = WatermarksPredictor(model, transforms, 'cpu')
 
 
-def predict(image):
-    result = predictor.…
-…
+def predict(image, threshold=0.5):
+    result = predictor.predict_image_confidence(image)
+    values = result.tolist()
+    wm_flag = 1 if values[1] >= threshold else 0
+    return 'watermarked' if wm_flag else 'clean', "%.4f" % values[1] # prints "watermarked"
 
 
 examples = glob.glob(os.path.join('images', 'clean', '*'))
 examples.extend(glob.glob(os.path.join('images', 'watermark', '*')))
-…
-…
+examples = [[e, 0.5] for e in examples]
+iface = gr.Interface(fn=predict, inputs=[gr.inputs.Image(type="pil"), gr.inputs.Number(label="threshold", default=0.5), ],
+                     examples=examples, outputs=[gr.outputs.Textbox(label="class"), gr.outputs.Textbox(label="wm_confidence")])
 iface.launch()
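For context, this commit swaps a hard argmax decision for softmax probabilities that the UI can threshold. A standalone sketch of the difference, using hypothetical 2-class logits (index 1 = watermark class, as in the diff above):

```python
import torch
import torch.nn.functional as F

# Hypothetical 2-class logits from the watermark classifier (index 1 = watermark).
outputs = torch.tensor([[-0.3, 1.9]])

# Old behaviour (predict_image): hard argmax label, no confidence exposed.
hard_label = torch.max(outputs, 1)[1].item()     # 1

# New behaviour (predict_image_confidence): softmax probabilities, so the
# caller (predict() in app.py) can apply its own threshold.
probs = F.softmax(outputs, dim=1).reshape(-1)    # tensor([0.0998, 0.9002])
wm_flag = probs[1].item() >= 0.5                 # True -> 'watermarked'
```

Note that `gr.inputs.Image`, `gr.inputs.Number`, and `gr.outputs.Textbox` are the legacy Gradio 2.x namespaces; in Gradio 3+ the same components are `gr.Image`, `gr.Number`, and `gr.Textbox`.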
wmdetection/pipelines/predictor.py CHANGED
@@ -51,6 +51,12 @@ class WatermarksPredictor:
         outputs = self.wm_model(input_img.to(self.device))
         result = torch.max(outputs, 1)[1].cpu().reshape(-1).tolist()[0]
         return result
+
+    def predict_image_confidence(self, pil_image):
+        pil_image = pil_image.convert("RGB")
+        input_img = self.classifier_transforms(pil_image).float().unsqueeze(0)
+        outputs = self.wm_model(input_img.to(self.device))
+        return torch.nn.functional.softmax(outputs, dim=1).cpu().reshape(-1)
 
     def run(self, files, num_workers=8, bs=8, pbar=True):
         eval_dataset = ImageDataset(files, self.classifier_transforms)
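A quick usage sketch for the new method, assuming `predictor` is constructed as in app.py; the image filename is hypothetical:

```python
from PIL import Image

# Built as in app.py: predictor = WatermarksPredictor(model, transforms, 'cpu')
image = Image.open('images/watermark/example.jpg')   # hypothetical example file
confidence = predictor.predict_image_confidence(image)

# A 1-D tensor of class probabilities summing to 1, e.g. [0.0312, 0.9688];
# confidence[1] is the watermark probability that app.py thresholds.
print(confidence.tolist())
```

Unlike `predict_image`, the tensor is returned as-is, leaving the caller to decide how to binarize it. Wrapping the forward pass in `torch.no_grad()` would avoid tracking gradients at inference time, though neither method in the diff does so.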