updated app
- .gitattributes +4 -0
- images/crab.png +0 -0
- images/fish.png +0 -0
- images/fish_2.png +3 -0
- images/fish_3.png +0 -0
- images/fish_4.png +0 -0
- images/fish_5.png +0 -0
- images/jelly_2.png +3 -0
- images/jelly_3.png +0 -0
- images/puff.png +0 -0
- images/red_fish_2.png +0 -0
- images/scene_2.png +3 -0
- images/scene_3.png +0 -0
- images/scene_4.png +0 -0
- images/scene_5.png +3 -0
- images/scene_6.png +0 -0
- images/soft_coral.png +0 -0
- images/starfish.png +0 -0
- images/starfish_2.png +0 -0
- inference.py +199 -0
.gitattributes
CHANGED
@@ -31,3 +31,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+images/scene_5.png filter=lfs diff=lfs merge=lfs -text
+images/fish_2.png filter=lfs diff=lfs merge=lfs -text
+images/jelly_2.png filter=lfs diff=lfs merge=lfs -text
+images/scene_2.png filter=lfs diff=lfs merge=lfs -text
images/crab.png ADDED
images/fish.png ADDED
images/fish_2.png ADDED (Git LFS)
images/fish_3.png ADDED
images/fish_4.png ADDED
images/fish_5.png ADDED
images/jelly_2.png ADDED (Git LFS)
images/jelly_3.png ADDED
images/puff.png ADDED
images/red_fish_2.png ADDED
images/scene_2.png ADDED (Git LFS)
images/scene_3.png ADDED
images/scene_4.png ADDED
images/scene_5.png ADDED (Git LFS)
images/scene_6.png ADDED
images/soft_coral.png ADDED
images/starfish.png ADDED
images/starfish_2.png ADDED
inference.py
ADDED
@@ -0,0 +1,199 @@
import os
import glob

import numpy as np
import detectron2
import torchvision
import cv2
import torch

from PIL import Image  # needed for Image.fromarray in run_inference

from detectron2 import model_zoo
from detectron2.data import Metadata
from detectron2.structures import BoxMode
from detectron2.utils.visualizer import Visualizer
from detectron2.config import get_cfg
from detectron2.utils.visualizer import ColorMode
from detectron2.modeling import build_model
import detectron2.data.transforms as T
from detectron2.checkpoint import DetectionCheckpointer

import matplotlib.pyplot as plt

# -----------------------------------------------------------------------------
# CONFIGS - loaded just the one time when the script is first run, to save time.
#
# This is where you set all the relevant config file and weight file variables:
# CONFIG_FILE - Training-specific config file for FathomNet
# WEIGHTS_FILE - Path to the model with FathomNet weights
# NMS_THRESH - NMS threshold applied to the all-boxes results
# SCORE_THRESH - Model score threshold

CONFIG_FILE = "fathomnet_config_v2_1280.yaml"
WEIGHTS_FILE = "model_final.pth"
NMS_THRESH = 0.45
SCORE_THRESH = 0.3

# A metadata object that contains metadata on each class category; used with
# Detectron2 for linking predictions to names and for visualizations.
fathomnet_metadata = Metadata(
    name='fathomnet_val',
    thing_classes=[
        'Anemone',
        'Fish',
        'Eel',
        'Gastropod',
        'Sea star',
        'Feather star',
        'Sea cucumber',
        'Urchin',
        'Glass sponge',
        'Sea fan',
        'Soft coral',
        'Sea pen',
        'Stony coral',
        'Ray',
        'Crab',
        'Shrimp',
        'Squat lobster',
        'Flatfish',
        'Sea spider',
        'Worm']
)

# This is where the model parameters are instantiated. There are a LOT of
# nested arguments in these yaml files, which merge the baseline defaults
# with dataset-specific parameters.
base_model_path = "COCO-Detection/retinanet_R_50_FPN_3x.yaml"

cfg = get_cfg()
cfg.MODEL.DEVICE = 'cpu'
cfg.merge_from_file(model_zoo.get_config_file(base_model_path))
cfg.merge_from_file(CONFIG_FILE)
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = SCORE_THRESH
cfg.MODEL.WEIGHTS = WEIGHTS_FILE

# Loading of the model weights, but more importantly this is where the model
# is actually instantiated as something that can take inputs and provide
# outputs. There is a lot of documentation about this, but not much in the
# way of straightforward tutorials.
model = build_model(cfg)
checkpointer = DetectionCheckpointer(model)
checkpointer.load(cfg.MODEL.WEIGHTS)
model.eval()

# Create two augmentations and make a list to iterate over
aug1 = T.ResizeShortestEdge(short_edge_length=[cfg.INPUT.MIN_SIZE_TEST],
                            max_size=cfg.INPUT.MAX_SIZE_TEST,
                            sample_style="choice")

aug2 = T.ResizeShortestEdge(short_edge_length=[1080],
                            max_size=1980,
                            sample_style="choice")

augmentations = [aug1, aug2]

# We use a separate NMS step because Detectron2 only does NMS within each
# class, and we want to run NMS across all boxes.
post_process_nms = torchvision.ops.nms
# -----------------------------------------------------------------------------


def run_inference(test_image):
    """Runs the inference pipeline on a single input image. The image is
    opened, augmented, and run through the model, which outputs bounding
    boxes and class categories for each object detected. These are then
    passed back to the calling function."""

    # Load the image and get its height and width. Iterate over each
    # augmentation: apply the augmentation, run the model, perform NMS
    # thresholding, and instantiate a useful object for visualizing the
    # outputs. Saves a list of output Instances.
    img = cv2.imread(test_image)
    im_height, im_width, _ = img.shape
    v_inf = Visualizer(img[:, :, ::-1],
                       metadata=fathomnet_metadata,
                       scale=1.0,
                       instance_mode=ColorMode.IMAGE_BW)

    insts = []

    # iterate over input augmentations (apply resizing)
    for augmentation in augmentations:
        im = augmentation.get_transform(img).apply_image(img)

        # pre-process the image by reshaping and converting to a tensor, then
        # pass it to the model, which outputs a dict describing all detections
        with torch.no_grad():
            im = torch.as_tensor(im.astype("float32").transpose(2, 0, 1))
            model_outputs = model([{"image": im,
                                    "height": im_height,
                                    "width": im_width}])[0]

        # populate the list with all outputs
        for _ in range(len(model_outputs['instances'])):
            insts.append(model_outputs['instances'][_])

    # Concatenate the model outputs and run NMS thresholding on all of them;
    # instantiate a dummy Instances object to concatenate the instances
    model_inst = detectron2.structures.instances.Instances([im_height,
                                                            im_width])

    xx = model_inst.cat(insts)[
        post_process_nms(model_inst.cat(insts).pred_boxes.tensor,
                         model_inst.cat(insts).scores,
                         NMS_THRESH).to("cpu").tolist()]

    print(test_image + ' - Number of predictions:', len(xx))
    out_inf_raw = v_inf.draw_instance_predictions(xx.to("cpu"))
    out_pil = Image.fromarray(out_inf_raw.get_image()).convert('RGB')

    # Convert the predictions output by Detectron2 to the TATOR format.
    predictions = convert_predictions(xx, v_inf.metadata.thing_classes)

    return predictions, out_pil


def convert_predictions(xx, thing_classes):
    """Helper function to post-process the predictions made by the Detectron2
    codebase to work with TATOR input requirements."""

    predictions = []

    for _ in range(len(xx)):

        # Obtain the prediction (a single instance) at this index
        instance = xx.__getitem__(_)

        # Map the box coordinates to the variables
        x, y, x2, y2 = map(float, instance.pred_boxes.tensor[0])
        w, h = x2 - x, y2 - y

        # Use the class list to get the common name (string); get the confidence score.
        class_category = thing_classes[int(instance.pred_classes[0])]
        confidence_score = float(instance.scores[0])

        # Create a spec dict for TATOR
        prediction = {'x': x,
                      'y': y,
                      'width': w,
                      'height': h,
                      'class_category': class_category,
                      'confidence': confidence_score}

        predictions.append(prediction)

    return predictions


if __name__ == "__main__":

    # For demo purposes: run through a couple of test
    # images and then output the predictions in a folder.
    test_images = glob.glob("images/*.png")
    os.makedirs("output", exist_ok=True)

    for test_image in test_images:
        predictions, out_img = run_inference(test_image)
        # save the annotated image alongside the printed prediction count
        out_img.save(os.path.join("output", os.path.basename(test_image)))

    print("Done.")
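For reference, a minimal sketch of driving run_inference from another script. This is an illustration, not part of the commit; it assumes fathomnet_config_v2_1280.yaml, model_final.pth, and the images/ folder added above are present in the working directory, and the output filename is just an example.

# Hypothetical usage sketch - importing the module builds the model once at import time.
from inference import run_inference

predictions, annotated = run_inference("images/crab.png")

# Each entry is a TATOR-style spec dict with keys:
# 'x', 'y', 'width', 'height', 'class_category', 'confidence'
for p in predictions:
    print(f"{p['class_category']}: {p['confidence']:.2f} at ({p['x']:.0f}, {p['y']:.0f})")

# The second return value is a PIL image with the detections drawn on it.
annotated.save("crab_predictions.png")  # example output path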