Jordan Pierce committed
Commit 9657406 · 2 Parent(s): 6300104 e076b2e

updated app
.gitattributes CHANGED
@@ -31,3 +31,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+images/scene_5.png filter=lfs diff=lfs merge=lfs -text
+images/fish_2.png filter=lfs diff=lfs merge=lfs -text
+images/jelly_2.png filter=lfs diff=lfs merge=lfs -text
+images/scene_2.png filter=lfs diff=lfs merge=lfs -text
images/crab.png ADDED
images/fish.png ADDED
images/fish_2.png ADDED

Git LFS Details

  • SHA256: 7a4c79f95a5e2175d34e3740fcac83147f17b34ebed938c35b6d59bc9a184019
  • Pointer size: 132 Bytes
  • Size of remote file: 4.44 MB
images/fish_3.png ADDED
images/fish_4.png ADDED
images/fish_5.png ADDED
images/jelly_2.png ADDED

Git LFS Details

  • SHA256: b90ce6a8a471781249b9f3c7567bd831ff64dab824dd5a714065628b2cdf41d9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.91 MB
images/jelly_3.png ADDED
images/puff.png ADDED
images/red_fish_2.png ADDED
images/scene_2.png ADDED

Git LFS Details

  • SHA256: 5352f28fb997d2bca33c0b64a437994b87abb409cf36f543ce2009b20e71f985
  • Pointer size: 132 Bytes
  • Size of remote file: 3.51 MB
images/scene_3.png ADDED
images/scene_4.png ADDED
images/scene_5.png ADDED

Git LFS Details

  • SHA256: 30276e3c9b6eadfaa6538e52e0a814e20642e97cb4444e785bd4e433070ae24d
  • Pointer size: 132 Bytes
  • Size of remote file: 3.1 MB
images/scene_6.png ADDED
images/soft_coral.png ADDED
images/starfish.png ADDED
images/starfish_2.png ADDED
inference.py ADDED
@@ -0,0 +1,199 @@
import glob

import cv2
import torch
import torchvision

import detectron2
import detectron2.data.transforms as T
from detectron2 import model_zoo
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import Metadata
from detectron2.modeling import build_model
from detectron2.utils.visualizer import ColorMode, Visualizer

# Pillow is needed for Image.fromarray in run_inference
from PIL import Image
# -----------------------------------------------------------------------------
# CONFIGS - loaded just once, when the script is first run, to save time.
#
# Set all of the relevant config-file and weight-file variables here:
#   CONFIG_FILE  - Training-specific config file for FathomNet
#   WEIGHTS_FILE - Path to the model with FathomNet weights
#   NMS_THRESH   - NMS threshold applied across all predicted boxes
#   SCORE_THRESH - Minimum confidence score for a detection to be kept

CONFIG_FILE = "fathomnet_config_v2_1280.yaml"
WEIGHTS_FILE = "model_final.pth"
NMS_THRESH = 0.45
SCORE_THRESH = 0.3
# A Metadata object holding the class categories; Detectron2 uses it to link
# integer predictions to class names and for visualizations.
fathomnet_metadata = Metadata(
    name='fathomnet_val',
    thing_classes=[
        'Anemone',
        'Fish',
        'Eel',
        'Gastropod',
        'Sea star',
        'Feather star',
        'Sea cucumber',
        'Urchin',
        'Glass sponge',
        'Sea fan',
        'Soft coral',
        'Sea pen',
        'Stony coral',
        'Ray',
        'Crab',
        'Shrimp',
        'Squat lobster',
        'Flatfish',
        'Sea spider',
        'Worm']
)
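An aside, not part of the committed file: the integer class ids the model predicts index straight into `thing_classes` in list order, so the mapping is easy to sanity-check:

# Illustration only (not in inference.py): class ids index into
# thing_classes in list order.
print(fathomnet_metadata.thing_classes[1])   # Fish
print(fathomnet_metadata.thing_classes[10])  # Soft coral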
# This is where the model parameters are instantiated. There are a LOT of
# nested arguments in these yaml files; the baseline defaults get merged
# with the dataset-specific parameters.
base_model_path = "COCO-Detection/retinanet_R_50_FPN_3x.yaml"

cfg = get_cfg()
cfg.MODEL.DEVICE = 'cpu'
cfg.merge_from_file(model_zoo.get_config_file(base_model_path))
cfg.merge_from_file(CONFIG_FILE)
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = SCORE_THRESH
cfg.MODEL.WEIGHTS = WEIGHTS_FILE
# Load the model weights; more importantly, this is where the model is
# actually instantiated as something that can take inputs and produce
# outputs. There is a lot of documentation about this, but not much in the
# way of straightforward tutorials.
model = build_model(cfg)
checkpointer = DetectionCheckpointer(model)
checkpointer.load(cfg.MODEL.WEIGHTS)
model.eval()
# Create two resize augmentations and put them in a list to iterate over
aug1 = T.ResizeShortestEdge(short_edge_length=[cfg.INPUT.MIN_SIZE_TEST],
                            max_size=cfg.INPUT.MAX_SIZE_TEST,
                            sample_style="choice")

aug2 = T.ResizeShortestEdge(short_edge_length=[1080],
                            max_size=1980,
                            sample_style="choice")

augmentations = [aug1, aug2]
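As a side note (not in the committed file), the second augmentation scales the short edge to 1080 px and caps the long edge at 1980 px. A minimal sketch of that behaviour, assuming a made-up 2160x3840 frame:

# Illustration only: resize behaviour of aug2 on a dummy 2160x3840 frame.
import numpy as np
dummy = np.zeros((2160, 3840, 3), dtype=np.uint8)
resized = aug2.get_transform(dummy).apply_image(dummy)
print(resized.shape)  # (1080, 1920, 3): short edge -> 1080, long edge 1920 <= 1980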
# We use a separate NMS step because Detectron2 only applies NMS within each
# class, and we want NMS across all boxes regardless of class.
post_process_nms = torchvision.ops.nms
# -----------------------------------------------------------------------------
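For reference, `torchvision.ops.nms` takes `(x1, y1, x2, y2)` boxes plus scores and returns the indices of the boxes to keep, sorted by decreasing score. A standalone toy example (values made up):

# Illustration only: torchvision.ops.nms on three toy boxes.
import torch
import torchvision

boxes = torch.tensor([[0., 0., 10., 10.],     # box 0: IoU ~0.68 with box 1
                      [1., 1., 11., 11.],     # box 1: lower score -> suppressed
                      [50., 50., 60., 60.]])  # box 2: disjoint -> kept
scores = torch.tensor([0.9, 0.8, 0.7])
print(torchvision.ops.nms(boxes, scores, 0.45))  # tensor([0, 2])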
def run_inference(test_image):
    """Runs the inference pipeline on a single input image. The image is
    opened, augmented, and run through the model, which outputs bounding
    boxes and class categories for each object detected. These are then
    passed back to the caller."""

    # Load the image and get its height and width. For each augmentation:
    # apply it, run the model, and collect the outputs in a list; NMS
    # thresholding happens afterwards. A Visualizer is instantiated for
    # drawing the outputs.
    img = cv2.imread(test_image)
    im_height, im_width, _ = img.shape
    v_inf = Visualizer(img[:, :, ::-1],
                       metadata=fathomnet_metadata,
                       scale=1.0,
                       instance_mode=ColorMode.IMAGE_BW)

    insts = []

    # Iterate over the input augmentations (apply resizing)
    for augmentation in augmentations:
        im = augmentation.get_transform(img).apply_image(img)

        # Pre-process the image by reshaping and converting it to a tensor,
        # then pass it to the model, which outputs a dict describing all
        # detections
        with torch.no_grad():
            im = torch.as_tensor(im.astype("float32").transpose(2, 0, 1))
            model_outputs = model([{"image": im,
                                    "height": im_height,
                                    "width": im_width}])[0]

        # Populate the list with all outputs
        for i in range(len(model_outputs['instances'])):
            insts.append(model_outputs['instances'][i])

    # Concatenate the model outputs and run NMS thresholding on all of them;
    # a dummy Instances object is used to concatenate the instances
    model_inst = detectron2.structures.instances.Instances([im_height,
                                                            im_width])

    all_insts = model_inst.cat(insts)
    xx = all_insts[post_process_nms(all_insts.pred_boxes.tensor,
                                    all_insts.scores,
                                    NMS_THRESH).to("cpu").tolist()]

    print(test_image + ' - Number of predictions:', len(xx))
    out_inf_raw = v_inf.draw_instance_predictions(xx.to("cpu"))
    out_pil = Image.fromarray(out_inf_raw.get_image()).convert('RGB')

    # Convert the predictions output by Detectron2 into the TATOR format
    predictions = convert_predictions(xx, v_inf.metadata.thing_classes)

    return predictions, out_pil
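A quick usage sketch, not part of the file; the image path is an assumption (any of the demo images would do):

# Illustration only: what run_inference hands back.
preds, annotated = run_inference("images/fish.png")  # path is an assumption
annotated.show()  # PIL image with the detections drawn on
# preds is a list of TATOR-style dicts, e.g. (values made up):
# {'x': 12.3, 'y': 45.6, 'width': 80.0, 'height': 40.0,
#  'class_category': 'Fish', 'confidence': 0.87}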
def convert_predictions(xx, thing_classes):
    """Helper function to post-process the predictions made by the
    Detectron2 codebase to match TATOR's input requirements."""

    predictions = []

    for i in range(len(xx)):

        # Obtain the i-th prediction (an Instances object of length one)
        instance = xx[i]

        # Map the box coordinates to variables
        x, y, x2, y2 = map(float, instance.pred_boxes.tensor[0])
        w, h = x2 - x, y2 - y

        # Use the class list to get the common name (string) and grab the
        # confidence score
        class_category = thing_classes[int(instance.pred_classes[0])]
        confidence_score = float(instance.scores[0])

        # Create a spec dict for TATOR
        prediction = {'x': x,
                      'y': y,
                      'width': w,
                      'height': h,
                      'class_category': class_category,
                      'confidence': confidence_score}

        predictions.append(prediction)

    return predictions
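`convert_predictions` can be exercised without running the model by building a one-detection Instances object by hand; a minimal sketch with made-up values:

# Illustration only: exercising convert_predictions with a hand-built
# Instances object (all values made up).
import torch
from detectron2.structures import Boxes, Instances

fake = Instances((480, 640))
fake.pred_boxes = Boxes(torch.tensor([[10.0, 20.0, 110.0, 70.0]]))
fake.scores = torch.tensor([0.9])
fake.pred_classes = torch.tensor([1])

print(convert_predictions(fake, fathomnet_metadata.thing_classes))
# [{'x': 10.0, 'y': 20.0, 'width': 100.0, 'height': 50.0,
#   'class_category': 'Fish', 'confidence': 0.9 (up to float32 rounding)}]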
if __name__ == "__main__":

    # For demo purposes: run through a couple of test images. The
    # predictions are only computed here; see the sketch below for one way
    # to save them.
    test_images = glob.glob("images/*.png")

    for test_image in test_images:
        predictions, out_img = run_inference(test_image)

    print("Done.")
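The demo loop above discards its results; a minimal sketch of one way to persist them (the "predictions/" directory and file naming are assumptions, not part of the app):

# Illustration only: saving the annotated images and TATOR-style dicts.
# The "predictions/" directory and file names are assumptions.
import json
import os

os.makedirs("predictions", exist_ok=True)
for test_image in glob.glob("images/*.png"):
    preds, out_img = run_inference(test_image)
    stem = os.path.splitext(os.path.basename(test_image))[0]
    out_img.save(os.path.join("predictions", stem + "_pred.png"))
    with open(os.path.join("predictions", stem + ".json"), "w") as f:
        json.dump(preds, f, indent=2)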