{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"from transformers import AutoModel\n",
"from torchvision.transforms.functional import to_pil_image, pil_to_tensor\n",
"from torchmetrics.classification import BinaryF1Score, BinaryAveragePrecision\n",
"from tqdm.auto import tqdm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = AutoModel.from_pretrained(\"ductai199x/forensic-similarity-graph\", trust_remote_code=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = model.eval().to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image_paths = sorted(glob.glob(\"example_images/splicing-??.png\"))\n",
"gt_paths = sorted(glob.glob(\"example_images/splicing-??-gt.png\"))\n",
"image_vs_gt_paths = list(zip(image_paths, gt_paths))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad():\n",
" imgs = []\n",
" gts = []\n",
" img_preds = []\n",
" loc_preds = []\n",
" f1, mAP = BinaryF1Score(), BinaryAveragePrecision()\n",
" for image_path, gt_path in tqdm(image_vs_gt_paths):\n",
" image = pil_to_tensor(Image.open(image_path).convert(\"RGB\")).float() / 255\n",
" gt = ((pil_to_tensor(Image.open(gt_path).convert(\"L\")).float() / 255) < 0.9).int()\n",
" img_pred, loc_pred = model(image.unsqueeze(0).to(device))\n",
" img_pred, loc_pred = img_pred[0].cpu(), loc_pred[0].cpu()\n",
" f1.update(loc_pred[None, ...], gt)\n",
" mAP.update(loc_pred[None, ...], gt)\n",
" img_preds.append(img_pred)\n",
" loc_preds.append(loc_pred)\n",
" imgs.append(image)\n",
" gts.append(gt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f1.compute().item(), mAP.compute().item()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"col = 4 * 2\n",
"row = -(-len(image_vs_gt_paths) // 4)\n",
"fig, axs = plt.subplots(row, col)\n",
"fig.set_size_inches(3 * col, 3 * row)\n",
"for i, (img, gt, img_pred, loc_pred) in enumerate(zip(imgs, gts, img_preds, loc_preds)):\n",
" ax = axs[i // 4][(i % 4) * 2]\n",
" ax.imshow(to_pil_image(img))\n",
" ax = axs[i // 4][(i % 4) * 2 + 1]\n",
" ax.imshow(to_pil_image(gt.float()))\n",
" ax.imshow(loc_pred, alpha=0.5, cmap=\"coolwarm\")\n",
"\n",
"for ax in axs.flat:\n",
" ax.axis(\"off\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pyt_tf2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}