Spaces: Build error
innat committed
Commit · 5637560
1 Parent(s): 238391b
upload
- LICENSE +23 -0
- README.md +5 -4
- app.py +135 -0
- components.py +0 -0
- config.py +34 -0
- examples/2fd875eaa.jpg +0 -0
- examples/348a992bb.jpg +0 -0
- examples/51b3e36ab.jpg +0 -0
- examples/51f1be19e.jpg +0 -0
- examples/53f253011.jpg +0 -0
- examples/796707dd7.jpg +0 -0
- examples/aac893a91.jpg +0 -0
- examples/cb8d261a3.jpg +0 -0
- examples/cc3532ff6.jpg +0 -0
- examples/f5a1f0358.jpg +0 -0
- model.h5 +3 -0
- mrcnn/__init__.py +1 -0
- mrcnn/config.py +239 -0
- mrcnn/model.py +0 -0
- mrcnn/parallel_model.py +188 -0
- mrcnn/utils.py +984 -0
- mrcnn/visualize.py +624 -0
- requirements.txt +15 -0
- setup.cfg +4 -0
- setup.py +69 -0
- utils.py +13 -0
LICENSE
ADDED
@@ -0,0 +1,23 @@
Mask R-CNN

The MIT License (MIT)

Copyright (c) 2017 Matterport, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
README.md
CHANGED
@@ -1,10 +1,11 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Wheat Detect Demo
+emoji: 🌾
+colorFrom: green
+colorTo: indigo
 sdk: gradio
 sdk_version: 3.0.20
+python_version: 3.7
 app_file: app.py
 pinned: false
 ---
app.py
ADDED
@@ -0,0 +1,135 @@
# ------------ tackle some noisy warning
import os
import warnings


def warn(*args, **kwargs):
    pass


warnings.warn = warn
warnings.filterwarnings("ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import random

import gdown
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image

import mrcnn.model as modellib
from config import WheatDetectorConfig
from config import WheatInferenceConfig
from mrcnn import utils
from mrcnn import visualize
from mrcnn.model import log
from utils import get_ax


# for reproducibility
def seed_all(SEED):
    random.seed(SEED)
    np.random.seed(SEED)
    os.environ["PYTHONHASHSEED"] = str(SEED)


ORIG_SIZE = 1024
seed_all(42)

config = WheatDetectorConfig()
inference_config = WheatInferenceConfig()


def get_model_weight(model_id):
    """Get the trained weights."""
    if not os.path.exists("model.h5"):
        model_weight = gdown.download(id=model_id, quiet=False)
    else:
        model_weight = "model.h5"
    return model_weight


def get_model():
    """Get the model."""
    model = modellib.MaskRCNN(mode="inference", config=inference_config, model_dir="./")
    return model


def load_model(model_id):
    """Load trained model."""
    weight = get_model_weight(model_id)
    model = get_model()
    model.load_weights(weight, by_name=True)
    return model


def prepare_image(image):
    """Prepare incoming sample."""
    image = image[:, :, ::-1]
    resize_factor = ORIG_SIZE / config.IMAGE_SHAPE[0]

    # If grayscale. Convert to RGB for consistency.
    if len(image.shape) != 3 or image.shape[2] != 3:
        image = np.stack((image,) * 3, -1)

    resized_image, window, scale, padding, crop = utils.resize_image(
        image,
        min_dim=config.IMAGE_MIN_DIM,
        min_scale=config.IMAGE_MIN_SCALE,
        max_dim=config.IMAGE_MAX_DIM,
        mode=config.IMAGE_RESIZE_MODE,
    )

    return resized_image


def predict_fn(image):

    image = prepare_image(image)

    model = load_model(model_id="1k4_WGBAUJCPbkkHkvtscX2jufTqETNYd")
    results = model.detect([image])
    r = results[0]
    class_names = ["Wheat"] * len(r["rois"])

    image = visualize.display_instances(
        image,
        r["rois"],
        r["masks"],
        r["class_ids"],
        class_names,
        r["scores"],
        ax=get_ax(),
        title="Predictions",
    )

    return image[:, :, ::-1]

title="Global Wheat Detection with Mask-RCNN Model"
description="<strong>Model</strong>: Mask-RCNN. <strong>Backbone</strong>: ResNet-101. Trained on: <a href='https://www.kaggle.com/competitions/global-wheat-detection/overview'>Global Wheat Detection Dataset (Kaggle)</a>. </br>The code is written in <code>Keras (TensorFlow 1.14)</code>. One can run the full code on Kaggle: <a href='https://www.kaggle.com/code/ipythonx/keras-global-wheat-detection-with-mask-rcnn'>[Keras]:Global Wheat Detection with Mask-RCNN</a>"
article = "<p>The model received <strong>0.6449</strong> and <strong>0.5675</strong> mAP (0.5:0.75:0.05) on the public and private test dataset respectively. The above examples are from test dataset without ground truth bounding box. Details: <a href='https://www.kaggle.com/competitions/global-wheat-detection/data'>Global Wheat Dataset</a></p>"

iface = gr.Interface(
    fn=predict_fn,
    inputs=gr.inputs.Image(label="Input Image"),
    outputs=gr.outputs.Image(label="Prediction"),
    title=title,
    description=description,
    article=article,
    examples=[
        ["examples/2fd875eaa.jpg"],
        ["examples/51b3e36ab.jpg"],
        ["examples/51f1be19e.jpg"],
        ["examples/53f253011.jpg"],
        ["examples/348a992bb.jpg"],
        ["examples/796707dd7.jpg"],
        ["examples/aac893a91.jpg"],
        ["examples/cb8d261a3.jpg"],
        ["examples/cc3532ff6.jpg"],
        ["examples/f5a1f0358.jpg"],
    ],
)
iface.launch()
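For anyone reviewing the inference path above, here is a minimal standalone sketch of the same flow. It is not part of the commit: it assumes the weights already exist locally as model.h5 and that the mrcnn package from this repo is importable, and it deliberately avoids importing app.py, since that module calls iface.launch() at import time.

# Hypothetical smoke test mirroring app.py's prepare_image() + detect() path.
import numpy as np
from PIL import Image

import mrcnn.model as modellib
from config import WheatInferenceConfig
from mrcnn import utils

inference_config = WheatInferenceConfig()
model = modellib.MaskRCNN(mode="inference", config=inference_config, model_dir="./")
model.load_weights("model.h5", by_name=True)

# Same preprocessing as prepare_image(): channel flip plus "square" resize to 1024.
image = np.array(Image.open("examples/2fd875eaa.jpg"))[:, :, ::-1]
resized, _, _, _, _ = utils.resize_image(
    image,
    min_dim=inference_config.IMAGE_MIN_DIM,
    min_scale=inference_config.IMAGE_MIN_SCALE,
    max_dim=inference_config.IMAGE_MAX_DIM,
    mode=inference_config.IMAGE_RESIZE_MODE,
)

# BATCH_SIZE is 1 for the inference config, so detect() takes a single-image list.
r = model.detect([resized])[0]
print("wheat heads detected:", len(r["rois"]))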
components.py
ADDED
The diff for this file is too large to render. See raw diff.
config.py
ADDED
@@ -0,0 +1,34 @@
from mrcnn.config import Config


class WheatDetectorConfig(Config):
    # Give the configuration a recognizable name
    NAME = "wheat"
    GPU_COUNT = 1
    IMAGES_PER_GPU = 2
    BACKBONE = "resnet101"
    NUM_CLASSES = 2
    IMAGE_RESIZE_MODE = "square"
    IMAGE_MIN_DIM = 1024
    IMAGE_MAX_DIM = 1024
    STEPS_PER_EPOCH = 120
    BACKBONE_STRIDES = [4, 8, 16, 32, 64]
    RPN_ANCHOR_SCALES = (16, 32, 64, 128, 256)
    LEARNING_RATE = 0.005
    WEIGHT_DECAY = 0.0005
    TRAIN_ROIS_PER_IMAGE = 350
    DETECTION_MIN_CONFIDENCE = 0.60
    VALIDATION_STEPS = 60
    MAX_GT_INSTANCES = 500
    LOSS_WEIGHTS = {
        "rpn_class_loss": 1.0,
        "rpn_bbox_loss": 1.0,
        "mrcnn_class_loss": 1.0,
        "mrcnn_bbox_loss": 1.0,
        "mrcnn_mask_loss": 1.0,
    }


class WheatInferenceConfig(WheatDetectorConfig):
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
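As a quick sanity check (not part of the commit), the attributes that the base Config class computes in __init__ resolve as follows for the inference configuration above, assuming the mrcnn/config.py shown later in this diff:

from config import WheatInferenceConfig

cfg = WheatInferenceConfig()
print(cfg.BATCH_SIZE)   # 1, since GPU_COUNT * IMAGES_PER_GPU = 1 * 1
print(cfg.IMAGE_SHAPE)  # [1024 1024 3], from "square" resizing with IMAGE_MAX_DIM = 1024
cfg.display()           # dumps every configuration value, inherited and overridden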
examples/2fd875eaa.jpg
ADDED
examples/348a992bb.jpg
ADDED
examples/51b3e36ab.jpg
ADDED
examples/51f1be19e.jpg
ADDED
examples/53f253011.jpg
ADDED
examples/796707dd7.jpg
ADDED
examples/aac893a91.jpg
ADDED
examples/cb8d261a3.jpg
ADDED
examples/cc3532ff6.jpg
ADDED
examples/f5a1f0358.jpg
ADDED
model.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:637fb6450e1332ed6447088b8dc68a492c4a8d64782dabeaf6fc4819e3da03e3
size 255858144
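The entry above is a Git LFS pointer, not the weights themselves. As an illustrative check (not part of the commit), a fully downloaded model.h5 can be verified against the pointer by hashing the file and comparing it with the recorded oid:

import hashlib

EXPECTED_OID = "637fb6450e1332ed6447088b8dc68a492c4a8d64782dabeaf6fc4819e3da03e3"

sha = hashlib.sha256()
with open("model.h5", "rb") as f:
    # Hash in 1 MiB chunks so the ~255 MB file is not read into memory at once.
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)

print(sha.hexdigest() == EXPECTED_OID)  # True only if the real weights were fetched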
mrcnn/__init__.py
ADDED
@@ -0,0 +1 @@
+
mrcnn/config.py
ADDED
@@ -0,0 +1,239 @@
"""
Mask R-CNN
Base Configurations class.

Copyright (c) 2017 Matterport, Inc.
Licensed under the MIT License (see LICENSE for details)
Written by Waleed Abdulla
"""

import numpy as np

# Base Configuration Class
# Don't use this class directly. Instead, sub-class it and override
# the configurations you need to change.


class Config(object):
    """Base configuration class. For custom configurations, create a
    sub-class that inherits from this one and override properties
    that need to be changed.
    """

    # Name the configurations. For example, 'COCO', 'Experiment 3', ...etc.
    # Useful if your code needs to do things differently depending on which
    # experiment is running.
    NAME = None  # Override in sub-classes

    # NUMBER OF GPUs to use. When using only a CPU, this needs to be set to 1.
    GPU_COUNT = 1

    # Number of images to train with on each GPU. A 12GB GPU can typically
    # handle 2 images of 1024x1024px.
    # Adjust based on your GPU memory and image sizes. Use the highest
    # number that your GPU can handle for best performance.
    IMAGES_PER_GPU = 2

    # Number of training steps per epoch
    # This doesn't need to match the size of the training set. Tensorboard
    # updates are saved at the end of each epoch, so setting this to a
    # smaller number means getting more frequent TensorBoard updates.
    # Validation stats are also calculated at each epoch end and they
    # might take a while, so don't set this too small to avoid spending
    # a lot of time on validation stats.
    STEPS_PER_EPOCH = 1000

    # Number of validation steps to run at the end of every training epoch.
    # A bigger number improves accuracy of validation stats, but slows
    # down the training.
    VALIDATION_STEPS = 50

    # Backbone network architecture
    # Supported values are: resnet50, resnet101.
    # You can also provide a callable that should have the signature
    # of model.resnet_graph. If you do so, you need to supply a callable
    # to COMPUTE_BACKBONE_SHAPE as well
    BACKBONE = "resnet101"

    # Only useful if you supply a callable to BACKBONE. Should compute
    # the shape of each layer of the FPN Pyramid.
    # See model.compute_backbone_shapes
    COMPUTE_BACKBONE_SHAPE = None

    # The strides of each layer of the FPN Pyramid. These values
    # are based on a Resnet101 backbone.
    BACKBONE_STRIDES = [4, 8, 16, 32, 64]

    # Size of the fully-connected layers in the classification graph
    FPN_CLASSIF_FC_LAYERS_SIZE = 1024

    # Size of the top-down layers used to build the feature pyramid
    TOP_DOWN_PYRAMID_SIZE = 256

    # Number of classification classes (including background)
    NUM_CLASSES = 1  # Override in sub-classes

    # Length of square anchor side in pixels
    RPN_ANCHOR_SCALES = (32, 64, 128, 256, 512)

    # Ratios of anchors at each cell (width/height)
    # A value of 1 represents a square anchor, and 0.5 is a wide anchor
    RPN_ANCHOR_RATIOS = [0.5, 1, 2]

    # Anchor stride
    # If 1 then anchors are created for each cell in the backbone feature map.
    # If 2, then anchors are created for every other cell, and so on.
    RPN_ANCHOR_STRIDE = 1

    # Non-max suppression threshold to filter RPN proposals.
    # You can increase this during training to generate more propsals.
    RPN_NMS_THRESHOLD = 0.7

    # How many anchors per image to use for RPN training
    RPN_TRAIN_ANCHORS_PER_IMAGE = 256

    # ROIs kept after tf.nn.top_k and before non-maximum suppression
    PRE_NMS_LIMIT = 6000

    # ROIs kept after non-maximum suppression (training and inference)
    POST_NMS_ROIS_TRAINING = 2000
    POST_NMS_ROIS_INFERENCE = 1000

    # If enabled, resizes instance masks to a smaller size to reduce
    # memory load. Recommended when using high-resolution images.
    USE_MINI_MASK = True
    MINI_MASK_SHAPE = (56, 56)  # (height, width) of the mini-mask

    # Input image resizing
    # Generally, use the "square" resizing mode for training and predicting
    # and it should work well in most cases. In this mode, images are scaled
    # up such that the small side is = IMAGE_MIN_DIM, but ensuring that the
    # scaling doesn't make the long side > IMAGE_MAX_DIM. Then the image is
    # padded with zeros to make it a square so multiple images can be put
    # in one batch.
    # Available resizing modes:
    # none:   No resizing or padding. Return the image unchanged.
    # square: Resize and pad with zeros to get a square image
    #         of size [max_dim, max_dim].
    # pad64:  Pads width and height with zeros to make them multiples of 64.
    #         If IMAGE_MIN_DIM or IMAGE_MIN_SCALE are not None, then it scales
    #         up before padding. IMAGE_MAX_DIM is ignored in this mode.
    #         The multiple of 64 is needed to ensure smooth scaling of feature
    #         maps up and down the 6 levels of the FPN pyramid (2**6=64).
    # crop:   Picks random crops from the image. First, scales the image based
    #         on IMAGE_MIN_DIM and IMAGE_MIN_SCALE, then picks a random crop of
    #         size IMAGE_MIN_DIM x IMAGE_MIN_DIM. Can be used in training only.
    #         IMAGE_MAX_DIM is not used in this mode.
    IMAGE_RESIZE_MODE = "square"
    IMAGE_MIN_DIM = 800
    IMAGE_MAX_DIM = 1024
    # Minimum scaling ratio. Checked after MIN_IMAGE_DIM and can force further
    # up scaling. For example, if set to 2 then images are scaled up to double
    # the width and height, or more, even if MIN_IMAGE_DIM doesn't require it.
    # However, in 'square' mode, it can be overruled by IMAGE_MAX_DIM.
    IMAGE_MIN_SCALE = 0
    # Number of color channels per image. RGB = 3, grayscale = 1, RGB-D = 4
    # Changing this requires other changes in the code. See the WIKI for more
    # details: https://github.com/matterport/Mask_RCNN/wiki
    IMAGE_CHANNEL_COUNT = 3

    # Image mean (RGB)
    MEAN_PIXEL = np.array([123.7, 116.8, 103.9])

    # Number of ROIs per image to feed to classifier/mask heads
    # The Mask RCNN paper uses 512 but often the RPN doesn't generate
    # enough positive proposals to fill this and keep a positive:negative
    # ratio of 1:3. You can increase the number of proposals by adjusting
    # the RPN NMS threshold.
    TRAIN_ROIS_PER_IMAGE = 200

    # Percent of positive ROIs used to train classifier/mask heads
    ROI_POSITIVE_RATIO = 0.33

    # Pooled ROIs
    POOL_SIZE = 7
    MASK_POOL_SIZE = 14

    # Shape of output mask
    # To change this you also need to change the neural network mask branch
    MASK_SHAPE = [28, 28]

    # Maximum number of ground truth instances to use in one image
    MAX_GT_INSTANCES = 100

    # Bounding box refinement standard deviation for RPN and final detections.
    RPN_BBOX_STD_DEV = np.array([0.1, 0.1, 0.2, 0.2])
    BBOX_STD_DEV = np.array([0.1, 0.1, 0.2, 0.2])

    # Max number of final detections
    DETECTION_MAX_INSTANCES = 100

    # Minimum probability value to accept a detected instance
    # ROIs below this threshold are skipped
    DETECTION_MIN_CONFIDENCE = 0.7

    # Non-maximum suppression threshold for detection
    DETECTION_NMS_THRESHOLD = 0.3

    # Learning rate and momentum
    # The Mask RCNN paper uses lr=0.02, but on TensorFlow it causes
    # weights to explode. Likely due to differences in optimizer
    # implementation.
    LEARNING_RATE = 0.001
    LEARNING_MOMENTUM = 0.9

    # Weight decay regularization
    WEIGHT_DECAY = 0.0001

    # Loss weights for more precise optimization.
    # Can be used for R-CNN training setup.
    LOSS_WEIGHTS = {
        "rpn_class_loss": 1.0,
        "rpn_bbox_loss": 1.0,
        "mrcnn_class_loss": 1.0,
        "mrcnn_bbox_loss": 1.0,
        "mrcnn_mask_loss": 1.0,
    }

    # Use RPN ROIs or externally generated ROIs for training
    # Keep this True for most situations. Set to False if you want to train
    # the head branches on ROI generated by code rather than the ROIs from
    # the RPN. For example, to debug the classifier head without having to
    # train the RPN.
    USE_RPN_ROIS = True

    # Train or freeze batch normalization layers
    #     None: Train BN layers. This is the normal mode
    #     False: Freeze BN layers. Good when using a small batch size
    #     True: (don't use). Set layer in training mode even when predicting
    TRAIN_BN = False  # Defaulting to False since batch size is often small

    # Gradient norm clipping
    GRADIENT_CLIP_NORM = 5.0

    def __init__(self):
        """Set values of computed attributes."""
        # Effective batch size
        self.BATCH_SIZE = self.IMAGES_PER_GPU * self.GPU_COUNT

        # Input image size
        if self.IMAGE_RESIZE_MODE == "crop":
            self.IMAGE_SHAPE = np.array(
                [self.IMAGE_MIN_DIM, self.IMAGE_MIN_DIM, self.IMAGE_CHANNEL_COUNT]
            )
        else:
            self.IMAGE_SHAPE = np.array(
                [self.IMAGE_MAX_DIM, self.IMAGE_MAX_DIM, self.IMAGE_CHANNEL_COUNT]
            )

        # Image meta data length
        # See compose_image_meta() for details
        self.IMAGE_META_SIZE = 1 + 3 + 3 + 4 + 1 + self.NUM_CLASSES

    def display(self):
        """Display Configuration values."""
        print("\nConfigurations:")
        for a in dir(self):
            if not a.startswith("__") and not callable(getattr(self, a)):
                print("{:30} {}".format(a, getattr(self, a)))
        print("\n")
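The docstring above asks users to subclass Config rather than edit it, which is exactly what the repository's own config.py (earlier in this diff) does. For illustration only, a minimal override might look like the following; the class name and values here are hypothetical:

from mrcnn.config import Config


class TinyInferenceConfig(Config):
    # Hypothetical example: background plus one foreground class, batch size 1.
    NAME = "tiny"
    NUM_CLASSES = 2
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1


cfg = TinyInferenceConfig()
cfg.display()  # prints every attribute, including the computed BATCH_SIZE and IMAGE_SHAPE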
mrcnn/model.py
ADDED
The diff for this file is too large to render. See raw diff.
mrcnn/parallel_model.py
ADDED
@@ -0,0 +1,188 @@
"""
Mask R-CNN
Multi-GPU Support for Keras.

Copyright (c) 2017 Matterport, Inc.
Licensed under the MIT License (see LICENSE for details)
Written by Waleed Abdulla

Ideas and a small code snippets from these sources:
https://github.com/fchollet/keras/issues/2436
https://medium.com/@kuza55/transparent-multi-gpu-training-on-tensorflow-with-keras-8b0016fd9012
https://github.com/avolkov1/keras_experiments/blob/master/keras_exp/multigpu/
https://github.com/fchollet/keras/blob/master/keras/utils/training_utils.py
"""

import keras.backend as K
import keras.layers as KL
import keras.models as KM
import tensorflow as tf


class ParallelModel(KM.Model):
    """Subclasses the standard Keras Model and adds multi-GPU support.
    It works by creating a copy of the model on each GPU. Then it slices
    the inputs and sends a slice to each copy of the model, and then
    merges the outputs together and applies the loss on the combined
    outputs.
    """

    def __init__(self, keras_model, gpu_count):
        """Class constructor.
        keras_model: The Keras model to parallelize
        gpu_count: Number of GPUs. Must be > 1
        """
        self.inner_model = keras_model
        self.gpu_count = gpu_count
        merged_outputs = self.make_parallel()
        super(ParallelModel, self).__init__(
            inputs=self.inner_model.inputs, outputs=merged_outputs
        )

    def __getattribute__(self, attrname):
        """Redirect loading and saving methods to the inner model. That's where
        the weights are stored."""
        if "load" in attrname or "save" in attrname:
            return getattr(self.inner_model, attrname)
        return super(ParallelModel, self).__getattribute__(attrname)

    def summary(self, *args, **kwargs):
        """Override summary() to display summaries of both, the wrapper
        and inner models."""
        super(ParallelModel, self).summary(*args, **kwargs)
        self.inner_model.summary(*args, **kwargs)

    def make_parallel(self):
        """Creates a new wrapper model that consists of multiple replicas of
        the original model placed on different GPUs.
        """
        # Slice inputs. Slice inputs on the CPU to avoid sending a copy
        # of the full inputs to all GPUs. Saves on bandwidth and memory.
        input_slices = {
            name: tf.split(x, self.gpu_count)
            for name, x in zip(self.inner_model.input_names, self.inner_model.inputs)
        }

        output_names = self.inner_model.output_names
        outputs_all = []
        for i in range(len(self.inner_model.outputs)):
            outputs_all.append([])

        # Run the model call() on each GPU to place the ops there
        for i in range(self.gpu_count):
            with tf.device("/gpu:%d" % i):
                with tf.name_scope("tower_%d" % i):
                    # Run a slice of inputs through this replica
                    zipped_inputs = zip(
                        self.inner_model.input_names, self.inner_model.inputs
                    )
                    inputs = [
                        KL.Lambda(
                            lambda s: input_slices[name][i],
                            output_shape=lambda s: (None,) + s[1:],
                        )(tensor)
                        for name, tensor in zipped_inputs
                    ]
                    # Create the model replica and get the outputs
                    outputs = self.inner_model(inputs)
                    if not isinstance(outputs, list):
                        outputs = [outputs]
                    # Save the outputs for merging back together later
                    for l, o in enumerate(outputs):
                        outputs_all[l].append(o)

        # Merge outputs on CPU
        with tf.device("/cpu:0"):
            merged = []
            for outputs, name in zip(outputs_all, output_names):
                # Concatenate or average outputs?
                # Outputs usually have a batch dimension and we concatenate
                # across it. If they don't, then the output is likely a loss
                # or a metric value that gets averaged across the batch.
                # Keras expects losses and metrics to be scalars.
                if K.int_shape(outputs[0]) == ():
                    # Average
                    m = KL.Lambda(lambda o: tf.add_n(o) / len(outputs), name=name)(
                        outputs
                    )
                else:
                    # Concatenate
                    m = KL.Concatenate(axis=0, name=name)(outputs)
                merged.append(m)
        return merged


if __name__ == "__main__":
    # Testing code below. It creates a simple model to train on MNIST and
    # tries to run it on 2 GPUs. It saves the graph so it can be viewed
    # in TensorBoard. Run it as:
    #
    # python3 parallel_model.py

    import os

    import keras.optimizers
    import numpy as np
    from keras.datasets import mnist
    from keras.preprocessing.image import ImageDataGenerator

    GPU_COUNT = 2

    # Root directory of the project
    ROOT_DIR = os.path.abspath("../")

    # Directory to save logs and trained model
    MODEL_DIR = os.path.join(ROOT_DIR, "logs")

    def build_model(x_train, num_classes):
        # Reset default graph. Keras leaves old ops in the graph,
        # which are ignored for execution but clutter graph
        # visualization in TensorBoard.
        tf.reset_default_graph()

        inputs = KL.Input(shape=x_train.shape[1:], name="input_image")
        x = KL.Conv2D(32, (3, 3), activation="relu", padding="same", name="conv1")(
            inputs
        )
        x = KL.Conv2D(64, (3, 3), activation="relu", padding="same", name="conv2")(x)
        x = KL.MaxPooling2D(pool_size=(2, 2), name="pool1")(x)
        x = KL.Flatten(name="flat1")(x)
        x = KL.Dense(128, activation="relu", name="dense1")(x)
        x = KL.Dense(num_classes, activation="softmax", name="dense2")(x)

        return KM.Model(inputs, x, "digit_classifier_model")

    # Load MNIST Data
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = np.expand_dims(x_train, -1).astype("float32") / 255
    x_test = np.expand_dims(x_test, -1).astype("float32") / 255

    print("x_train shape:", x_train.shape)
    print("x_test shape:", x_test.shape)

    # Build data generator and model
    datagen = ImageDataGenerator()
    model = build_model(x_train, 10)

    # Add multi-GPU support.
    model = ParallelModel(model, GPU_COUNT)

    optimizer = keras.optimizers.SGD(lr=0.01, momentum=0.9, clipnorm=5.0)

    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=optimizer,
        metrics=["accuracy"],
    )

    model.summary()

    # Train
    model.fit_generator(
        datagen.flow(x_train, y_train, batch_size=64),
        steps_per_epoch=50,
        epochs=10,
        verbose=1,
        validation_data=(x_test, y_test),
        callbacks=[keras.callbacks.TensorBoard(log_dir=MODEL_DIR, write_graph=True)],
    )
mrcnn/utils.py
ADDED
|
@@ -0,0 +1,984 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Mask R-CNN
|
| 3 |
+
Common utility functions and classes.
|
| 4 |
+
|
| 5 |
+
Copyright (c) 2017 Matterport, Inc.
|
| 6 |
+
Licensed under the MIT License (see LICENSE for details)
|
| 7 |
+
Written by Waleed Abdulla
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import math
|
| 12 |
+
import os
|
| 13 |
+
import random
|
| 14 |
+
import shutil
|
| 15 |
+
import sys
|
| 16 |
+
import urllib.request
|
| 17 |
+
import warnings
|
| 18 |
+
from distutils.version import LooseVersion
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import scipy
|
| 22 |
+
import skimage.color
|
| 23 |
+
import skimage.io
|
| 24 |
+
import skimage.transform
|
| 25 |
+
import tensorflow as tf
|
| 26 |
+
|
| 27 |
+
# URL from which to download the latest COCO trained weights
|
| 28 |
+
COCO_MODEL_URL = (
|
| 29 |
+
"https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_coco.h5"
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
############################################################
|
| 34 |
+
# Bounding Boxes
|
| 35 |
+
############################################################
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def extract_bboxes(mask):
|
| 39 |
+
"""Compute bounding boxes from masks.
|
| 40 |
+
mask: [height, width, num_instances]. Mask pixels are either 1 or 0.
|
| 41 |
+
|
| 42 |
+
Returns: bbox array [num_instances, (y1, x1, y2, x2)].
|
| 43 |
+
"""
|
| 44 |
+
boxes = np.zeros([mask.shape[-1], 4], dtype=np.int32)
|
| 45 |
+
for i in range(mask.shape[-1]):
|
| 46 |
+
m = mask[:, :, i]
|
| 47 |
+
# Bounding box.
|
| 48 |
+
horizontal_indicies = np.where(np.any(m, axis=0))[0]
|
| 49 |
+
vertical_indicies = np.where(np.any(m, axis=1))[0]
|
| 50 |
+
if horizontal_indicies.shape[0]:
|
| 51 |
+
x1, x2 = horizontal_indicies[[0, -1]]
|
| 52 |
+
y1, y2 = vertical_indicies[[0, -1]]
|
| 53 |
+
# x2 and y2 should not be part of the box. Increment by 1.
|
| 54 |
+
x2 += 1
|
| 55 |
+
y2 += 1
|
| 56 |
+
else:
|
| 57 |
+
# No mask for this instance. Might happen due to
|
| 58 |
+
# resizing or cropping. Set bbox to zeros
|
| 59 |
+
x1, x2, y1, y2 = 0, 0, 0, 0
|
| 60 |
+
boxes[i] = np.array([y1, x1, y2, x2])
|
| 61 |
+
return boxes.astype(np.int32)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def compute_iou(box, boxes, box_area, boxes_area):
|
| 65 |
+
"""Calculates IoU of the given box with the array of the given boxes.
|
| 66 |
+
box: 1D vector [y1, x1, y2, x2]
|
| 67 |
+
boxes: [boxes_count, (y1, x1, y2, x2)]
|
| 68 |
+
box_area: float. the area of 'box'
|
| 69 |
+
boxes_area: array of length boxes_count.
|
| 70 |
+
|
| 71 |
+
Note: the areas are passed in rather than calculated here for
|
| 72 |
+
efficiency. Calculate once in the caller to avoid duplicate work.
|
| 73 |
+
"""
|
| 74 |
+
# Calculate intersection areas
|
| 75 |
+
y1 = np.maximum(box[0], boxes[:, 0])
|
| 76 |
+
y2 = np.minimum(box[2], boxes[:, 2])
|
| 77 |
+
x1 = np.maximum(box[1], boxes[:, 1])
|
| 78 |
+
x2 = np.minimum(box[3], boxes[:, 3])
|
| 79 |
+
intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0)
|
| 80 |
+
union = box_area + boxes_area[:] - intersection[:]
|
| 81 |
+
iou = intersection / union
|
| 82 |
+
return iou
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def compute_overlaps(boxes1, boxes2):
|
| 86 |
+
"""Computes IoU overlaps between two sets of boxes.
|
| 87 |
+
boxes1, boxes2: [N, (y1, x1, y2, x2)].
|
| 88 |
+
|
| 89 |
+
For better performance, pass the largest set first and the smaller second.
|
| 90 |
+
"""
|
| 91 |
+
# Areas of anchors and GT boxes
|
| 92 |
+
area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
|
| 93 |
+
area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
|
| 94 |
+
|
| 95 |
+
# Compute overlaps to generate matrix [boxes1 count, boxes2 count]
|
| 96 |
+
# Each cell contains the IoU value.
|
| 97 |
+
overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
|
| 98 |
+
for i in range(overlaps.shape[1]):
|
| 99 |
+
box2 = boxes2[i]
|
| 100 |
+
overlaps[:, i] = compute_iou(box2, boxes1, area2[i], area1)
|
| 101 |
+
return overlaps
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def compute_overlaps_masks(masks1, masks2):
|
| 105 |
+
"""Computes IoU overlaps between two sets of masks.
|
| 106 |
+
masks1, masks2: [Height, Width, instances]
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
# If either set of masks is empty return empty result
|
| 110 |
+
if masks1.shape[-1] == 0 or masks2.shape[-1] == 0:
|
| 111 |
+
return np.zeros((masks1.shape[-1], masks2.shape[-1]))
|
| 112 |
+
# flatten masks and compute their areas
|
| 113 |
+
masks1 = np.reshape(masks1 > 0.5, (-1, masks1.shape[-1])).astype(np.float32)
|
| 114 |
+
masks2 = np.reshape(masks2 > 0.5, (-1, masks2.shape[-1])).astype(np.float32)
|
| 115 |
+
area1 = np.sum(masks1, axis=0)
|
| 116 |
+
area2 = np.sum(masks2, axis=0)
|
| 117 |
+
|
| 118 |
+
# intersections and union
|
| 119 |
+
intersections = np.dot(masks1.T, masks2)
|
| 120 |
+
union = area1[:, None] + area2[None, :] - intersections
|
| 121 |
+
overlaps = intersections / union
|
| 122 |
+
|
| 123 |
+
return overlaps
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def non_max_suppression(boxes, scores, threshold):
|
| 127 |
+
"""Performs non-maximum suppression and returns indices of kept boxes.
|
| 128 |
+
boxes: [N, (y1, x1, y2, x2)]. Notice that (y2, x2) lays outside the box.
|
| 129 |
+
scores: 1-D array of box scores.
|
| 130 |
+
threshold: Float. IoU threshold to use for filtering.
|
| 131 |
+
"""
|
| 132 |
+
assert boxes.shape[0] > 0
|
| 133 |
+
if boxes.dtype.kind != "f":
|
| 134 |
+
boxes = boxes.astype(np.float32)
|
| 135 |
+
|
| 136 |
+
# Compute box areas
|
| 137 |
+
y1 = boxes[:, 0]
|
| 138 |
+
x1 = boxes[:, 1]
|
| 139 |
+
y2 = boxes[:, 2]
|
| 140 |
+
x2 = boxes[:, 3]
|
| 141 |
+
area = (y2 - y1) * (x2 - x1)
|
| 142 |
+
|
| 143 |
+
# Get indicies of boxes sorted by scores (highest first)
|
| 144 |
+
ixs = scores.argsort()[::-1]
|
| 145 |
+
|
| 146 |
+
pick = []
|
| 147 |
+
while len(ixs) > 0:
|
| 148 |
+
# Pick top box and add its index to the list
|
| 149 |
+
i = ixs[0]
|
| 150 |
+
pick.append(i)
|
| 151 |
+
# Compute IoU of the picked box with the rest
|
| 152 |
+
iou = compute_iou(boxes[i], boxes[ixs[1:]], area[i], area[ixs[1:]])
|
| 153 |
+
# Identify boxes with IoU over the threshold. This
|
| 154 |
+
# returns indices into ixs[1:], so add 1 to get
|
| 155 |
+
# indices into ixs.
|
| 156 |
+
remove_ixs = np.where(iou > threshold)[0] + 1
|
| 157 |
+
# Remove indices of the picked and overlapped boxes.
|
| 158 |
+
ixs = np.delete(ixs, remove_ixs)
|
| 159 |
+
ixs = np.delete(ixs, 0)
|
| 160 |
+
return np.array(pick, dtype=np.int32)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def apply_box_deltas(boxes, deltas):
|
| 164 |
+
"""Applies the given deltas to the given boxes.
|
| 165 |
+
boxes: [N, (y1, x1, y2, x2)]. Note that (y2, x2) is outside the box.
|
| 166 |
+
deltas: [N, (dy, dx, log(dh), log(dw))]
|
| 167 |
+
"""
|
| 168 |
+
boxes = boxes.astype(np.float32)
|
| 169 |
+
# Convert to y, x, h, w
|
| 170 |
+
height = boxes[:, 2] - boxes[:, 0]
|
| 171 |
+
width = boxes[:, 3] - boxes[:, 1]
|
| 172 |
+
center_y = boxes[:, 0] + 0.5 * height
|
| 173 |
+
center_x = boxes[:, 1] + 0.5 * width
|
| 174 |
+
# Apply deltas
|
| 175 |
+
center_y += deltas[:, 0] * height
|
| 176 |
+
center_x += deltas[:, 1] * width
|
| 177 |
+
height *= np.exp(deltas[:, 2])
|
| 178 |
+
width *= np.exp(deltas[:, 3])
|
| 179 |
+
# Convert back to y1, x1, y2, x2
|
| 180 |
+
y1 = center_y - 0.5 * height
|
| 181 |
+
x1 = center_x - 0.5 * width
|
| 182 |
+
y2 = y1 + height
|
| 183 |
+
x2 = x1 + width
|
| 184 |
+
return np.stack([y1, x1, y2, x2], axis=1)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def box_refinement_graph(box, gt_box):
|
| 188 |
+
"""Compute refinement needed to transform box to gt_box.
|
| 189 |
+
box and gt_box are [N, (y1, x1, y2, x2)]
|
| 190 |
+
"""
|
| 191 |
+
box = tf.cast(box, tf.float32)
|
| 192 |
+
gt_box = tf.cast(gt_box, tf.float32)
|
| 193 |
+
|
| 194 |
+
height = box[:, 2] - box[:, 0]
|
| 195 |
+
width = box[:, 3] - box[:, 1]
|
| 196 |
+
center_y = box[:, 0] + 0.5 * height
|
| 197 |
+
center_x = box[:, 1] + 0.5 * width
|
| 198 |
+
|
| 199 |
+
gt_height = gt_box[:, 2] - gt_box[:, 0]
|
| 200 |
+
gt_width = gt_box[:, 3] - gt_box[:, 1]
|
| 201 |
+
gt_center_y = gt_box[:, 0] + 0.5 * gt_height
|
| 202 |
+
gt_center_x = gt_box[:, 1] + 0.5 * gt_width
|
| 203 |
+
|
| 204 |
+
dy = (gt_center_y - center_y) / height
|
| 205 |
+
dx = (gt_center_x - center_x) / width
|
| 206 |
+
dh = tf.log(gt_height / height)
|
| 207 |
+
dw = tf.log(gt_width / width)
|
| 208 |
+
|
| 209 |
+
result = tf.stack([dy, dx, dh, dw], axis=1)
|
| 210 |
+
return result
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def box_refinement(box, gt_box):
|
| 214 |
+
"""Compute refinement needed to transform box to gt_box.
|
| 215 |
+
box and gt_box are [N, (y1, x1, y2, x2)]. (y2, x2) is
|
| 216 |
+
assumed to be outside the box.
|
| 217 |
+
"""
|
| 218 |
+
box = box.astype(np.float32)
|
| 219 |
+
gt_box = gt_box.astype(np.float32)
|
| 220 |
+
|
| 221 |
+
height = box[:, 2] - box[:, 0]
|
| 222 |
+
width = box[:, 3] - box[:, 1]
|
| 223 |
+
center_y = box[:, 0] + 0.5 * height
|
| 224 |
+
center_x = box[:, 1] + 0.5 * width
|
| 225 |
+
|
| 226 |
+
gt_height = gt_box[:, 2] - gt_box[:, 0]
|
| 227 |
+
gt_width = gt_box[:, 3] - gt_box[:, 1]
|
| 228 |
+
gt_center_y = gt_box[:, 0] + 0.5 * gt_height
|
| 229 |
+
gt_center_x = gt_box[:, 1] + 0.5 * gt_width
|
| 230 |
+
|
| 231 |
+
dy = (gt_center_y - center_y) / height
|
| 232 |
+
dx = (gt_center_x - center_x) / width
|
| 233 |
+
dh = np.log(gt_height / height)
|
| 234 |
+
dw = np.log(gt_width / width)
|
| 235 |
+
|
| 236 |
+
return np.stack([dy, dx, dh, dw], axis=1)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
############################################################
|
| 240 |
+
# Dataset
|
| 241 |
+
############################################################
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class Dataset(object):
|
| 245 |
+
"""The base class for dataset classes.
|
| 246 |
+
To use it, create a new class that adds functions specific to the dataset
|
| 247 |
+
you want to use. For example:
|
| 248 |
+
|
| 249 |
+
class CatsAndDogsDataset(Dataset):
|
| 250 |
+
def load_cats_and_dogs(self):
|
| 251 |
+
...
|
| 252 |
+
def load_mask(self, image_id):
|
| 253 |
+
...
|
| 254 |
+
def image_reference(self, image_id):
|
| 255 |
+
...
|
| 256 |
+
|
| 257 |
+
See COCODataset and ShapesDataset as examples.
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
def __init__(self, class_map=None):
|
| 261 |
+
self._image_ids = []
|
| 262 |
+
self.image_info = []
|
| 263 |
+
# Background is always the first class
|
| 264 |
+
self.class_info = [{"source": "", "id": 0, "name": "BG"}]
|
| 265 |
+
self.source_class_ids = {}
|
| 266 |
+
|
| 267 |
+
def add_class(self, source, class_id, class_name):
|
| 268 |
+
assert "." not in source, "Source name cannot contain a dot"
|
| 269 |
+
# Does the class exist already?
|
| 270 |
+
for info in self.class_info:
|
| 271 |
+
if info["source"] == source and info["id"] == class_id:
|
| 272 |
+
# source.class_id combination already available, skip
|
| 273 |
+
return
|
| 274 |
+
# Add the class
|
| 275 |
+
self.class_info.append(
|
| 276 |
+
{
|
| 277 |
+
"source": source,
|
| 278 |
+
"id": class_id,
|
| 279 |
+
"name": class_name,
|
| 280 |
+
}
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
def add_image(self, source, image_id, path, **kwargs):
|
| 284 |
+
image_info = {
|
| 285 |
+
"id": image_id,
|
| 286 |
+
"source": source,
|
| 287 |
+
"path": path,
|
| 288 |
+
}
|
| 289 |
+
image_info.update(kwargs)
|
| 290 |
+
self.image_info.append(image_info)
|
| 291 |
+
|
| 292 |
+
def image_reference(self, image_id):
|
| 293 |
+
"""Return a link to the image in its source Website or details about
|
| 294 |
+
the image that help looking it up or debugging it.
|
| 295 |
+
|
| 296 |
+
Override for your dataset, but pass to this function
|
| 297 |
+
if you encounter images not in your dataset.
|
| 298 |
+
"""
|
| 299 |
+
return ""
|
| 300 |
+
|
| 301 |
+
def prepare(self, class_map=None):
|
| 302 |
+
"""Prepares the Dataset class for use.
|
| 303 |
+
|
| 304 |
+
TODO: class map is not supported yet. When done, it should handle mapping
|
| 305 |
+
classes from different datasets to the same class ID.
|
| 306 |
+
"""
|
| 307 |
+
|
| 308 |
+
def clean_name(name):
|
| 309 |
+
"""Returns a shorter version of object names for cleaner display."""
|
| 310 |
+
return ",".join(name.split(",")[:1])
|
| 311 |
+
|
| 312 |
+
# Build (or rebuild) everything else from the info dicts.
|
| 313 |
+
self.num_classes = len(self.class_info)
|
| 314 |
+
self.class_ids = np.arange(self.num_classes)
|
| 315 |
+
self.class_names = [clean_name(c["name"]) for c in self.class_info]
|
| 316 |
+
self.num_images = len(self.image_info)
|
| 317 |
+
self._image_ids = np.arange(self.num_images)
|
| 318 |
+
|
| 319 |
+
# Mapping from source class and image IDs to internal IDs
|
| 320 |
+
self.class_from_source_map = {
|
| 321 |
+
"{}.{}".format(info["source"], info["id"]): id
|
| 322 |
+
for info, id in zip(self.class_info, self.class_ids)
|
| 323 |
+
}
|
| 324 |
+
self.image_from_source_map = {
|
| 325 |
+
"{}.{}".format(info["source"], info["id"]): id
|
| 326 |
+
for info, id in zip(self.image_info, self.image_ids)
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
# Map sources to class_ids they support
|
| 330 |
+
self.sources = list(set([i["source"] for i in self.class_info]))
|
| 331 |
+
self.source_class_ids = {}
|
| 332 |
+
# Loop over datasets
|
| 333 |
+
for source in self.sources:
|
| 334 |
+
self.source_class_ids[source] = []
|
| 335 |
+
# Find classes that belong to this dataset
|
| 336 |
+
for i, info in enumerate(self.class_info):
|
| 337 |
+
# Include BG class in all datasets
|
| 338 |
+
if i == 0 or source == info["source"]:
|
| 339 |
+
self.source_class_ids[source].append(i)
|
| 340 |
+
|
| 341 |
+
def map_source_class_id(self, source_class_id):
|
| 342 |
+
"""Takes a source class ID and returns the int class ID assigned to it.
|
| 343 |
+
|
| 344 |
+
For example:
|
| 345 |
+
dataset.map_source_class_id("coco.12") -> 23
|
| 346 |
+
"""
|
| 347 |
+
return self.class_from_source_map[source_class_id]
|
| 348 |
+
|
| 349 |
+
def get_source_class_id(self, class_id, source):
|
| 350 |
+
"""Map an internal class ID to the corresponding class ID in the source dataset."""
|
| 351 |
+
info = self.class_info[class_id]
|
| 352 |
+
assert info["source"] == source
|
| 353 |
+
return info["id"]
|
| 354 |
+
|
| 355 |
+
@property
|
| 356 |
+
def image_ids(self):
|
| 357 |
+
return self._image_ids
|
| 358 |
+
|
| 359 |
+
def source_image_link(self, image_id):
|
| 360 |
+
"""Returns the path or URL to the image.
|
| 361 |
+
Override this to return a URL to the image if it's available online for easy
|
| 362 |
+
debugging.
|
| 363 |
+
"""
|
| 364 |
+
return self.image_info[image_id]["path"]
|
| 365 |
+
|
| 366 |
+
def load_image(self, image_id):
|
| 367 |
+
"""Load the specified image and return a [H,W,3] Numpy array."""
|
| 368 |
+
# Load image
|
| 369 |
+
image = skimage.io.imread(self.image_info[image_id]["path"])
|
| 370 |
+
# If grayscale. Convert to RGB for consistency.
|
| 371 |
+
if image.ndim != 3:
|
| 372 |
+
image = skimage.color.gray2rgb(image)
|
| 373 |
+
# If has an alpha channel, remove it for consistency
|
| 374 |
+
if image.shape[-1] == 4:
|
| 375 |
+
image = image[..., :3]
|
| 376 |
+
return image
|
| 377 |
+
|
| 378 |
+
def load_mask(self, image_id):
|
| 379 |
+
"""Load instance masks for the given image.
|
| 380 |
+
|
| 381 |
+
Different datasets use different ways to store masks. Override this
|
| 382 |
+
method to load instance masks and return them in the form of am
|
| 383 |
+
array of binary masks of shape [height, width, instances].
|
| 384 |
+
|
| 385 |
+
Returns:
|
| 386 |
+
masks: A bool array of shape [height, width, instance count] with
|
| 387 |
+
a binary mask per instance.
|
| 388 |
+
class_ids: a 1D array of class IDs of the instance masks.
|
| 389 |
+
"""
|
| 390 |
+
# Override this function to load a mask from your dataset.
|
| 391 |
+
# Otherwise, it returns an empty mask.
|
| 392 |
+
logging.warning(
|
| 393 |
+
"You are using the default load_mask(), maybe you need to define your own one."
|
| 394 |
+
)
|
| 395 |
+
mask = np.empty([0, 0, 0])
|
| 396 |
+
class_ids = np.empty([0], np.int32)
|
| 397 |
+
return mask, class_ids
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def resize_image(image, min_dim=None, max_dim=None, min_scale=None, mode="square"):
|
| 401 |
+
"""Resizes an image keeping the aspect ratio unchanged.
|
| 402 |
+
|
| 403 |
+
min_dim: if provided, resizes the image such that it's smaller
|
| 404 |
+
dimension == min_dim
|
| 405 |
+
max_dim: if provided, ensures that the image longest side doesn't
|
| 406 |
+
exceed this value.
|
| 407 |
+
min_scale: if provided, ensure that the image is scaled up by at least
|
| 408 |
+
this percent even if min_dim doesn't require it.
|
| 409 |
+
mode: Resizing mode.
|
| 410 |
+
none: No resizing. Return the image unchanged.
|
| 411 |
+
square: Resize and pad with zeros to get a square image
|
| 412 |
+
of size [max_dim, max_dim].
|
| 413 |
+
pad64: Pads width and height with zeros to make them multiples of 64.
|
| 414 |
+
If min_dim or min_scale are provided, it scales the image up
|
| 415 |
+
before padding. max_dim is ignored in this mode.
|
| 416 |
+
The multiple of 64 is needed to ensure smooth scaling of feature
|
| 417 |
+
maps up and down the 6 levels of the FPN pyramid (2**6=64).
|
| 418 |
+
crop: Picks random crops from the image. First, scales the image based
|
| 419 |
+
on min_dim and min_scale, then picks a random crop of
|
| 420 |
+
size min_dim x min_dim. Can be used in training only.
|
| 421 |
+
max_dim is not used in this mode.
|
| 422 |
+
|
| 423 |
+
Returns:
|
| 424 |
+
image: the resized image
|
| 425 |
+
window: (y1, x1, y2, x2). If max_dim is provided, padding might
|
| 426 |
+
be inserted in the returned image. If so, this window is the
|
| 427 |
+
coordinates of the image part of the full image (excluding
|
| 428 |
+
the padding). The x2, y2 pixels are not included.
|
| 429 |
+
scale: The scale factor used to resize the image
|
| 430 |
+
padding: Padding added to the image [(top, bottom), (left, right), (0, 0)]
|
| 431 |
+
"""
|
| 432 |
+
# Keep track of image dtype and return results in the same dtype
|
| 433 |
+
image_dtype = image.dtype
|
| 434 |
+
# Default window (y1, x1, y2, x2) and default scale == 1.
|
| 435 |
+
h, w = image.shape[:2]
|
| 436 |
+
window = (0, 0, h, w)
|
| 437 |
+
scale = 1
|
| 438 |
+
padding = [(0, 0), (0, 0), (0, 0)]
|
| 439 |
+
crop = None
|
| 440 |
+
|
| 441 |
+
if mode == "none":
|
| 442 |
+
return image, window, scale, padding, crop
|
| 443 |
+
|
| 444 |
+
# Scale?
|
| 445 |
+
if min_dim:
|
| 446 |
+
# Scale up but not down
|
| 447 |
+
scale = max(1, min_dim / min(h, w))
|
| 449 |
+
if min_scale and scale < min_scale:
|
| 450 |
+
scale = min_scale
|
| 451 |
+
|
| 452 |
+
# Does it exceed max dim?
|
| 453 |
+
if max_dim and mode == "square":
|
| 454 |
+
image_max = max(h, w)
|
| 455 |
+
if round(image_max * scale) > max_dim:
|
| 456 |
+
scale = max_dim / image_max
|
| 457 |
+
|
| 458 |
+
# Resize image using bilinear interpolation
|
| 459 |
+
if scale != 1:
|
| 460 |
+
image = resize(image, (round(h * scale), round(w * scale)), preserve_range=True)
|
| 461 |
+
|
| 462 |
+
# Need padding or cropping?
|
| 463 |
+
if mode == "square":
|
| 464 |
+
# Get new height and width
|
| 465 |
+
h, w = image.shape[:2]
|
| 466 |
+
top_pad = (max_dim - h) // 2
|
| 467 |
+
bottom_pad = max_dim - h - top_pad
|
| 468 |
+
left_pad = (max_dim - w) // 2
|
| 469 |
+
right_pad = max_dim - w - left_pad
|
| 470 |
+
padding = [(top_pad, bottom_pad), (left_pad, right_pad), (0, 0)]
|
| 471 |
+
image = np.pad(image, padding, mode="constant", constant_values=0)
|
| 472 |
+
window = (top_pad, left_pad, h + top_pad, w + left_pad)
|
| 473 |
+
elif mode == "pad64":
|
| 474 |
+
h, w = image.shape[:2]
|
| 475 |
+
# Both sides must be divisible by 64
|
| 476 |
+
assert min_dim % 64 == 0, "Minimum dimension must be a multiple of 64"
|
| 477 |
+
# Height
|
| 478 |
+
if h % 64 > 0:
|
| 479 |
+
max_h = h - (h % 64) + 64
|
| 480 |
+
top_pad = (max_h - h) // 2
|
| 481 |
+
bottom_pad = max_h - h - top_pad
|
| 482 |
+
else:
|
| 483 |
+
top_pad = bottom_pad = 0
|
| 484 |
+
# Width
|
| 485 |
+
if w % 64 > 0:
|
| 486 |
+
max_w = w - (w % 64) + 64
|
| 487 |
+
left_pad = (max_w - w) // 2
|
| 488 |
+
right_pad = max_w - w - left_pad
|
| 489 |
+
else:
|
| 490 |
+
left_pad = right_pad = 0
|
| 491 |
+
padding = [(top_pad, bottom_pad), (left_pad, right_pad), (0, 0)]
|
| 492 |
+
image = np.pad(image, padding, mode="constant", constant_values=0)
|
| 493 |
+
window = (top_pad, left_pad, h + top_pad, w + left_pad)
|
| 494 |
+
elif mode == "crop":
|
| 495 |
+
# Pick a random crop
|
| 496 |
+
h, w = image.shape[:2]
|
| 497 |
+
y = random.randint(0, (h - min_dim))
|
| 498 |
+
x = random.randint(0, (w - min_dim))
|
| 499 |
+
crop = (y, x, min_dim, min_dim)
|
| 500 |
+
image = image[y : y + min_dim, x : x + min_dim]
|
| 501 |
+
window = (0, 0, min_dim, min_dim)
|
| 502 |
+
else:
|
| 503 |
+
raise Exception("Mode {} not supported".format(mode))
|
| 504 |
+
return image.astype(image_dtype), window, scale, padding, crop
|
| 505 |
+
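# Usage sketch, assuming a dummy 300x400 RGB array (illustrative values only):
#   dummy = np.zeros((300, 400, 3), dtype=np.uint8)
#   resized, window, scale, padding, crop = resize_image(
#       dummy, min_dim=800, max_dim=1024, mode="square")
#   # resized.shape == (1024, 1024, 3); `window` marks the un-padded region.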
|
| 506 |
+
|
| 507 |
+
def resize_mask(mask, scale, padding, crop=None):
|
| 508 |
+
"""Resizes a mask using the given scale and padding.
|
| 509 |
+
Typically, you get the scale and padding from resize_image() to
|
| 510 |
+
ensure both the image and the mask are resized consistently.
|
| 511 |
+
|
| 512 |
+
scale: mask scaling factor
|
| 513 |
+
padding: Padding to add to the mask in the form
|
| 514 |
+
[(top, bottom), (left, right), (0, 0)]
|
| 515 |
+
"""
|
| 516 |
+
# Suppress warning from scipy 0.13.0, the output shape of zoom() is
|
| 517 |
+
# calculated with round() instead of int()
|
| 518 |
+
with warnings.catch_warnings():
|
| 519 |
+
warnings.simplefilter("ignore")
|
| 520 |
+
mask = scipy.ndimage.zoom(mask, zoom=[scale, scale, 1], order=0)
|
| 521 |
+
if crop is not None:
|
| 522 |
+
y, x, h, w = crop
|
| 523 |
+
mask = mask[y : y + h, x : x + w]
|
| 524 |
+
else:
|
| 525 |
+
mask = np.pad(mask, padding, mode="constant", constant_values=0)
|
| 526 |
+
return mask
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def minimize_mask(bbox, mask, mini_shape):
|
| 530 |
+
"""Resize masks to a smaller version to reduce memory load.
|
| 531 |
+
Mini-masks can be resized back to image scale using expand_masks()
|
| 532 |
+
|
| 533 |
+
See inspect_data.ipynb notebook for more details.
|
| 534 |
+
"""
|
| 535 |
+
mini_mask = np.zeros(mini_shape + (mask.shape[-1],), dtype=bool)
|
| 536 |
+
for i in range(mask.shape[-1]):
|
| 537 |
+
# Pick slice and cast to bool in case load_mask() returned wrong dtype
|
| 538 |
+
m = mask[:, :, i].astype(bool)
|
| 539 |
+
y1, x1, y2, x2 = bbox[i][:4]
|
| 540 |
+
m = m[y1:y2, x1:x2]
|
| 541 |
+
if m.size == 0:
|
| 542 |
+
raise Exception("Invalid bounding box with area of zero")
|
| 543 |
+
# Resize with bilinear interpolation
|
| 544 |
+
m = resize(m, mini_shape)
|
| 545 |
+
mini_mask[:, :, i] = np.around(m).astype(bool)
|
| 546 |
+
return mini_mask
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
def expand_mask(bbox, mini_mask, image_shape):
|
| 550 |
+
"""Resizes mini masks back to image size. Reverses the change
|
| 551 |
+
of minimize_mask().
|
| 552 |
+
|
| 553 |
+
See inspect_data.ipynb notebook for more details.
|
| 554 |
+
"""
|
| 555 |
+
mask = np.zeros(image_shape[:2] + (mini_mask.shape[-1],), dtype=bool)
|
| 556 |
+
for i in range(mask.shape[-1]):
|
| 557 |
+
m = mini_mask[:, :, i]
|
| 558 |
+
y1, x1, y2, x2 = bbox[i][:4]
|
| 559 |
+
h = y2 - y1
|
| 560 |
+
w = x2 - x1
|
| 561 |
+
# Resize with bilinear interpolation
|
| 562 |
+
m = resize(m, (h, w))
|
| 563 |
+
mask[y1:y2, x1:x2, i] = np.around(m).astype(bool)
|
| 564 |
+
return mask
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
# TODO: Build and use this function to reduce code duplication
|
| 568 |
+
def mold_mask(mask, config):
|
| 569 |
+
pass
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
def unmold_mask(mask, bbox, image_shape):
|
| 573 |
+
"""Converts a mask generated by the neural network to a format similar
|
| 574 |
+
to its original shape.
|
| 575 |
+
mask: [height, width] of type float. A small, typically 28x28 mask.
|
| 576 |
+
bbox: [y1, x1, y2, x2]. The box to fit the mask in.
|
| 577 |
+
|
| 578 |
+
Returns a binary mask with the same size as the original image.
|
| 579 |
+
"""
|
| 580 |
+
threshold = 0.5
|
| 581 |
+
y1, x1, y2, x2 = bbox
|
| 582 |
+
mask = resize(mask, (y2 - y1, x2 - x1))
|
| 583 |
+
mask = np.where(mask >= threshold, 1, 0).astype(bool)
|
| 584 |
+
|
| 585 |
+
# Put the mask in the right location.
|
| 586 |
+
full_mask = np.zeros(image_shape[:2], dtype=bool)
|
| 587 |
+
full_mask[y1:y2, x1:x2] = mask
|
| 588 |
+
return full_mask
|
| 589 |
+
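# Usage sketch, assuming a hypothetical 28x28 float mask and a 256x256 image:
#   small = np.random.rand(28, 28).astype(np.float32)
#   full = unmold_mask(small, bbox=np.array([40, 50, 90, 120]),
#                      image_shape=(256, 256, 3))
#   # `full` is a (256, 256) bool array, True only inside the given box.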
|
| 590 |
+
|
| 591 |
+
############################################################
|
| 592 |
+
# Anchors
|
| 593 |
+
############################################################
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def generate_anchors(scales, ratios, shape, feature_stride, anchor_stride):
|
| 597 |
+
"""
|
| 598 |
+
scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128]
|
| 599 |
+
ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
|
| 600 |
+
shape: [height, width] spatial shape of the feature map over which
|
| 601 |
+
to generate anchors.
|
| 602 |
+
feature_stride: Stride of the feature map relative to the image in pixels.
|
| 603 |
+
anchor_stride: Stride of anchors on the feature map. For example, if the
|
| 604 |
+
value is 2 then generate anchors for every other feature map pixel.
|
| 605 |
+
"""
|
| 606 |
+
# Get all combinations of scales and ratios
|
| 607 |
+
scales, ratios = np.meshgrid(np.array(scales), np.array(ratios))
|
| 608 |
+
scales = scales.flatten()
|
| 609 |
+
ratios = ratios.flatten()
|
| 610 |
+
|
| 611 |
+
# Enumerate heights and widths from scales and ratios
|
| 612 |
+
heights = scales / np.sqrt(ratios)
|
| 613 |
+
widths = scales * np.sqrt(ratios)
|
| 614 |
+
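# Note: each (scale, ratio) pair keeps the anchor area constant:
# heights * widths == scales ** 2, while widths / heights == ratios.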
|
| 615 |
+
# Enumerate shifts in feature space
|
| 616 |
+
shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride
|
| 617 |
+
shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride
|
| 618 |
+
shifts_x, shifts_y = np.meshgrid(shifts_x, shifts_y)
|
| 619 |
+
|
| 620 |
+
# Enumerate combinations of shifts, widths, and heights
|
| 621 |
+
box_widths, box_centers_x = np.meshgrid(widths, shifts_x)
|
| 622 |
+
box_heights, box_centers_y = np.meshgrid(heights, shifts_y)
|
| 623 |
+
|
| 624 |
+
# Reshape to get a list of (y, x) and a list of (h, w)
|
| 625 |
+
box_centers = np.stack([box_centers_y, box_centers_x], axis=2).reshape([-1, 2])
|
| 626 |
+
box_sizes = np.stack([box_heights, box_widths], axis=2).reshape([-1, 2])
|
| 627 |
+
|
| 628 |
+
# Convert to corner coordinates (y1, x1, y2, x2)
|
| 629 |
+
boxes = np.concatenate(
|
| 630 |
+
[box_centers - 0.5 * box_sizes, box_centers + 0.5 * box_sizes], axis=1
|
| 631 |
+
)
|
| 632 |
+
return boxes
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
def generate_pyramid_anchors(
|
| 636 |
+
scales, ratios, feature_shapes, feature_strides, anchor_stride
|
| 637 |
+
):
|
| 638 |
+
"""Generate anchors at different levels of a feature pyramid. Each scale
|
| 639 |
+
is associated with a level of the pyramid, but each ratio is used in
|
| 640 |
+
all levels of the pyramid.
|
| 641 |
+
|
| 642 |
+
Returns:
|
| 643 |
+
anchors: [N, (y1, x1, y2, x2)]. All generated anchors in one array. Sorted
|
| 644 |
+
with the same order of the given scales. So, anchors of scale[0] come
|
| 645 |
+
first, then anchors of scale[1], and so on.
|
| 646 |
+
"""
|
| 647 |
+
# Anchors
|
| 648 |
+
# [anchor_count, (y1, x1, y2, x2)]
|
| 649 |
+
anchors = []
|
| 650 |
+
for i in range(len(scales)):
|
| 651 |
+
anchors.append(
|
| 652 |
+
generate_anchors(
|
| 653 |
+
scales[i], ratios, feature_shapes[i], feature_strides[i], anchor_stride
|
| 654 |
+
)
|
| 655 |
+
)
|
| 656 |
+
return np.concatenate(anchors, axis=0)
|
| 657 |
+
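# Usage sketch with illustrative settings (not this project's config): a
# 3-level pyramid over strides 4/8/16 of a 256x256 input.
#   anchors = generate_pyramid_anchors(
#       scales=(32, 64, 128),                      # one scale per level
#       ratios=[0.5, 1, 2],                        # shared by all levels
#       feature_shapes=[(64, 64), (32, 32), (16, 16)],
#       feature_strides=[4, 8, 16],
#       anchor_stride=1)
#   # anchors.shape == (3 * (64*64 + 32*32 + 16*16), 4)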
|
| 658 |
+
|
| 659 |
+
############################################################
|
| 660 |
+
# Miscellaneous
|
| 661 |
+
############################################################
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
def trim_zeros(x):
|
| 665 |
+
"""It's common to have tensors larger than the available data and
|
| 666 |
+
pad with zeros. This function removes rows that are all zeros.
|
| 667 |
+
|
| 668 |
+
x: [rows, columns].
|
| 669 |
+
"""
|
| 670 |
+
assert len(x.shape) == 2
|
| 671 |
+
return x[~np.all(x == 0, axis=1)]
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
def compute_matches(
|
| 675 |
+
gt_boxes,
|
| 676 |
+
gt_class_ids,
|
| 677 |
+
gt_masks,
|
| 678 |
+
pred_boxes,
|
| 679 |
+
pred_class_ids,
|
| 680 |
+
pred_scores,
|
| 681 |
+
pred_masks,
|
| 682 |
+
iou_threshold=0.5,
|
| 683 |
+
score_threshold=0.0,
|
| 684 |
+
):
|
| 685 |
+
"""Finds matches between prediction and ground truth instances.
|
| 686 |
+
|
| 687 |
+
Returns:
|
| 688 |
+
gt_match: 1-D array. For each GT box it has the index of the matched
|
| 689 |
+
predicted box.
|
| 690 |
+
pred_match: 1-D array. For each predicted box, it has the index of
|
| 691 |
+
the matched ground truth box.
|
| 692 |
+
overlaps: [pred_boxes, gt_boxes] IoU overlaps.
|
| 693 |
+
"""
|
| 694 |
+
# Trim zero padding
|
| 695 |
+
# TODO: cleaner to do zero unpadding upstream
|
| 696 |
+
gt_boxes = trim_zeros(gt_boxes)
|
| 697 |
+
gt_masks = gt_masks[..., : gt_boxes.shape[0]]
|
| 698 |
+
pred_boxes = trim_zeros(pred_boxes)
|
| 699 |
+
pred_scores = pred_scores[: pred_boxes.shape[0]]
|
| 700 |
+
# Sort predictions by score from high to low
|
| 701 |
+
indices = np.argsort(pred_scores)[::-1]
|
| 702 |
+
pred_boxes = pred_boxes[indices]
|
| 703 |
+
pred_class_ids = pred_class_ids[indices]
|
| 704 |
+
pred_scores = pred_scores[indices]
|
| 705 |
+
pred_masks = pred_masks[..., indices]
|
| 706 |
+
|
| 707 |
+
# Compute IoU overlaps [pred_masks, gt_masks]
|
| 708 |
+
overlaps = compute_overlaps_masks(pred_masks, gt_masks)
|
| 709 |
+
|
| 710 |
+
# Loop through predictions and find matching ground truth boxes
|
| 711 |
+
match_count = 0
|
| 712 |
+
pred_match = -1 * np.ones([pred_boxes.shape[0]])
|
| 713 |
+
gt_match = -1 * np.ones([gt_boxes.shape[0]])
|
| 714 |
+
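# A value of -1 in gt_match / pred_match means the instance found no match at
# this IoU threshold; matched entries store the index of their counterpart.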
for i in range(len(pred_boxes)):
|
| 715 |
+
# Find best matching ground truth box
|
| 716 |
+
# 1. Sort matches by score
|
| 717 |
+
sorted_ixs = np.argsort(overlaps[i])[::-1]
|
| 718 |
+
# 2. Remove low scores
|
| 719 |
+
low_score_idx = np.where(overlaps[i, sorted_ixs] < score_threshold)[0]
|
| 720 |
+
if low_score_idx.size > 0:
|
| 721 |
+
sorted_ixs = sorted_ixs[: low_score_idx[0]]
|
| 722 |
+
# 3. Find the match
|
| 723 |
+
for j in sorted_ixs:
|
| 724 |
+
# If ground truth box is already matched, go to next one
|
| 725 |
+
if gt_match[j] > -1:
|
| 726 |
+
continue
|
| 727 |
+
# If we reach IoU smaller than the threshold, end the loop
|
| 728 |
+
iou = overlaps[i, j]
|
| 729 |
+
if iou < iou_threshold:
|
| 730 |
+
break
|
| 731 |
+
# Do we have a match?
|
| 732 |
+
if pred_class_ids[i] == gt_class_ids[j]:
|
| 733 |
+
match_count += 1
|
| 734 |
+
gt_match[j] = i
|
| 735 |
+
pred_match[i] = j
|
| 736 |
+
break
|
| 737 |
+
|
| 738 |
+
return gt_match, pred_match, overlaps
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
def compute_ap(
|
| 742 |
+
gt_boxes,
|
| 743 |
+
gt_class_ids,
|
| 744 |
+
gt_masks,
|
| 745 |
+
pred_boxes,
|
| 746 |
+
pred_class_ids,
|
| 747 |
+
pred_scores,
|
| 748 |
+
pred_masks,
|
| 749 |
+
iou_threshold=0.5,
|
| 750 |
+
):
|
| 751 |
+
"""Compute Average Precision at a set IoU threshold (default 0.5).
|
| 752 |
+
|
| 753 |
+
Returns:
|
| 754 |
+
mAP: Mean Average Precision
|
| 755 |
+
precisions: List of precisions at different class score thresholds.
|
| 756 |
+
recalls: List of recall values at different class score thresholds.
|
| 757 |
+
overlaps: [pred_boxes, gt_boxes] IoU overlaps.
|
| 758 |
+
"""
|
| 759 |
+
# Get matches and overlaps
|
| 760 |
+
gt_match, pred_match, overlaps = compute_matches(
|
| 761 |
+
gt_boxes,
|
| 762 |
+
gt_class_ids,
|
| 763 |
+
gt_masks,
|
| 764 |
+
pred_boxes,
|
| 765 |
+
pred_class_ids,
|
| 766 |
+
pred_scores,
|
| 767 |
+
pred_masks,
|
| 768 |
+
iou_threshold,
|
| 769 |
+
)
|
| 770 |
+
|
| 771 |
+
# Compute precision and recall at each prediction box step
|
| 772 |
+
precisions = np.cumsum(pred_match > -1) / (np.arange(len(pred_match)) + 1)
|
| 773 |
+
recalls = np.cumsum(pred_match > -1).astype(np.float32) / len(gt_match)
|
| 774 |
+
|
| 775 |
+
# Pad with start and end values to simplify the math
|
| 776 |
+
precisions = np.concatenate([[0], precisions, [0]])
|
| 777 |
+
recalls = np.concatenate([[0], recalls, [1]])
|
| 778 |
+
|
| 779 |
+
# Ensure precision values are monotonically decreasing. This way, the
|
| 780 |
+
# precision value at each recall threshold is the maximum it can be
|
| 781 |
+
# for all following recall thresholds, as specified by the VOC paper.
|
| 782 |
+
for i in range(len(precisions) - 2, -1, -1):
|
| 783 |
+
precisions[i] = np.maximum(precisions[i], precisions[i + 1])
|
| 784 |
+
|
| 785 |
+
# Compute mean AP over recall range
|
| 786 |
+
indices = np.where(recalls[:-1] != recalls[1:])[0] + 1
|
| 787 |
+
mAP = np.sum((recalls[indices] - recalls[indices - 1]) * precisions[indices])
|
| 788 |
+
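# Toy example of the computation above: with pred_match == [0, -1, 1] and two
# GT instances, precisions == [1.0, 0.5, 0.667] and recalls == [0.5, 0.5, 1.0];
# after padding and the backward maximum pass, the recall steps at 0.5 and 1.0
# contribute 0.5 * 1.0 + 0.5 * 0.667, so mAP is roughly 0.83.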
|
| 789 |
+
return mAP, precisions, recalls, overlaps
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
def compute_ap_range(
|
| 793 |
+
gt_box,
|
| 794 |
+
gt_class_id,
|
| 795 |
+
gt_mask,
|
| 796 |
+
pred_box,
|
| 797 |
+
pred_class_id,
|
| 798 |
+
pred_score,
|
| 799 |
+
pred_mask,
|
| 800 |
+
iou_thresholds=None,
|
| 801 |
+
verbose=1,
|
| 802 |
+
):
|
| 803 |
+
"""Compute AP over a range or IoU thresholds. Default range is 0.5-0.95."""
|
| 804 |
+
# Default is 0.5 to 0.95 with increments of 0.05
|
| 805 |
+
iou_thresholds = iou_thresholds or np.arange(0.5, 1.0, 0.05)
|
| 806 |
+
|
| 807 |
+
# Compute AP over range of IoU thresholds
|
| 808 |
+
AP = []
|
| 809 |
+
for iou_threshold in iou_thresholds:
|
| 810 |
+
ap, precisions, recalls, overlaps = compute_ap(
|
| 811 |
+
gt_box,
|
| 812 |
+
gt_class_id,
|
| 813 |
+
gt_mask,
|
| 814 |
+
pred_box,
|
| 815 |
+
pred_class_id,
|
| 816 |
+
pred_score,
|
| 817 |
+
pred_mask,
|
| 818 |
+
iou_threshold=iou_threshold,
|
| 819 |
+
)
|
| 820 |
+
if verbose:
|
| 821 |
+
print("AP @{:.2f}:\t {:.3f}".format(iou_threshold, ap))
|
| 822 |
+
AP.append(ap)
|
| 823 |
+
AP = np.array(AP).mean()
|
| 824 |
+
if verbose:
|
| 825 |
+
print(
|
| 826 |
+
"AP @{:.2f}-{:.2f}:\t {:.3f}".format(
|
| 827 |
+
iou_thresholds[0], iou_thresholds[-1], AP
|
| 828 |
+
)
|
| 829 |
+
)
|
| 830 |
+
return AP
|
| 831 |
+
|
| 832 |
+
|
| 833 |
+
def compute_recall(pred_boxes, gt_boxes, iou):
|
| 834 |
+
"""Compute the recall at the given IoU threshold. It's an indication
|
| 835 |
+
of how many GT boxes were found by the given prediction boxes.
|
| 836 |
+
|
| 837 |
+
pred_boxes: [N, (y1, x1, y2, x2)] in image coordinates
|
| 838 |
+
gt_boxes: [N, (y1, x1, y2, x2)] in image coordinates
|
| 839 |
+
"""
|
| 840 |
+
# Measure overlaps
|
| 841 |
+
overlaps = compute_overlaps(pred_boxes, gt_boxes)
|
| 842 |
+
iou_max = np.max(overlaps, axis=1)
|
| 843 |
+
iou_argmax = np.argmax(overlaps, axis=1)
|
| 844 |
+
positive_ids = np.where(iou_max >= iou)[0]
|
| 845 |
+
matched_gt_boxes = iou_argmax[positive_ids]
|
| 846 |
+
|
| 847 |
+
recall = len(set(matched_gt_boxes)) / gt_boxes.shape[0]
|
| 848 |
+
return recall, positive_ids
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
# ## Batch Slicing
|
| 852 |
+
# Some custom layers support a batch size of 1 only, and require a lot of work
|
| 853 |
+
# to support batches greater than 1. This function slices an input tensor
|
| 854 |
+
# across the batch dimension and feeds batches of size 1. Effectively,
|
| 855 |
+
# an easy way to support batches > 1 quickly with little code modification.
|
| 856 |
+
# In the long run, it's more efficient to modify the code to support large
|
| 857 |
+
# batches and getting rid of this function. Consider this a temporary solution
|
| 858 |
+
def batch_slice(inputs, graph_fn, batch_size, names=None):
|
| 859 |
+
"""Splits inputs into slices and feeds each slice to a copy of the given
|
| 860 |
+
computation graph and then combines the results. It allows you to run a
|
| 861 |
+
graph on a batch of inputs even if the graph is written to support one
|
| 862 |
+
instance only.
|
| 863 |
+
|
| 864 |
+
inputs: list of tensors. All must have the same first dimension length
|
| 865 |
+
graph_fn: A function that returns a TF tensor that's part of a graph.
|
| 866 |
+
batch_size: number of slices to divide the data into.
|
| 867 |
+
names: If provided, assigns names to the resulting tensors.
|
| 868 |
+
"""
|
| 869 |
+
if not isinstance(inputs, list):
|
| 870 |
+
inputs = [inputs]
|
| 871 |
+
|
| 872 |
+
outputs = []
|
| 873 |
+
for i in range(batch_size):
|
| 874 |
+
inputs_slice = [x[i] for x in inputs]
|
| 875 |
+
output_slice = graph_fn(*inputs_slice)
|
| 876 |
+
if not isinstance(output_slice, (tuple, list)):
|
| 877 |
+
output_slice = [output_slice]
|
| 878 |
+
outputs.append(output_slice)
|
| 879 |
+
# Change outputs from a list of slices where each is
|
| 880 |
+
# a list of outputs to a list of outputs and each has
|
| 881 |
+
# a list of slices
|
| 882 |
+
outputs = list(zip(*outputs))
|
| 883 |
+
|
| 884 |
+
if names is None:
|
| 885 |
+
names = [None] * len(outputs)
|
| 886 |
+
|
| 887 |
+
result = [tf.stack(o, axis=0, name=n) for o, n in zip(outputs, names)]
|
| 888 |
+
if len(result) == 1:
|
| 889 |
+
result = result[0]
|
| 890 |
+
|
| 891 |
+
return result
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
def download_trained_weights(coco_model_path, verbose=1):
|
| 895 |
+
"""Download COCO trained weights from Releases.
|
| 896 |
+
|
| 897 |
+
coco_model_path: local path of COCO trained weights
|
| 898 |
+
"""
|
| 899 |
+
if verbose > 0:
|
| 900 |
+
print("Downloading pretrained model to " + coco_model_path + " ...")
|
| 901 |
+
with urllib.request.urlopen(COCO_MODEL_URL) as resp, open(
|
| 902 |
+
coco_model_path, "wb"
|
| 903 |
+
) as out:
|
| 904 |
+
shutil.copyfileobj(resp, out)
|
| 905 |
+
if verbose > 0:
|
| 906 |
+
print("... done downloading pretrained model!")
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
def norm_boxes(boxes, shape):
|
| 910 |
+
"""Converts boxes from pixel coordinates to normalized coordinates.
|
| 911 |
+
boxes: [N, (y1, x1, y2, x2)] in pixel coordinates
|
| 912 |
+
shape: [..., (height, width)] in pixels
|
| 913 |
+
|
| 914 |
+
Note: In pixel coordinates (y2, x2) is outside the box. But in normalized
|
| 915 |
+
coordinates it's inside the box.
|
| 916 |
+
|
| 917 |
+
Returns:
|
| 918 |
+
[N, (y1, x1, y2, x2)] in normalized coordinates
|
| 919 |
+
"""
|
| 920 |
+
h, w = shape
|
| 921 |
+
scale = np.array([h - 1, w - 1, h - 1, w - 1])
|
| 922 |
+
shift = np.array([0, 0, 1, 1])
|
| 923 |
+
return np.divide((boxes - shift), scale).astype(np.float32)
|
| 924 |
+
|
| 925 |
+
|
| 926 |
+
def denorm_boxes(boxes, shape):
|
| 927 |
+
"""Converts boxes from normalized coordinates to pixel coordinates.
|
| 928 |
+
boxes: [N, (y1, x1, y2, x2)] in normalized coordinates
|
| 929 |
+
shape: [..., (height, width)] in pixels
|
| 930 |
+
|
| 931 |
+
Note: In pixel coordinates (y2, x2) is outside the box. But in normalized
|
| 932 |
+
coordinates it's inside the box.
|
| 933 |
+
|
| 934 |
+
Returns:
|
| 935 |
+
[N, (y1, x1, y2, x2)] in pixel coordinates
|
| 936 |
+
"""
|
| 937 |
+
h, w = shape
|
| 938 |
+
scale = np.array([h - 1, w - 1, h - 1, w - 1])
|
| 939 |
+
shift = np.array([0, 0, 1, 1])
|
| 940 |
+
return np.around(np.multiply(boxes, scale) + shift).astype(np.int32)
|
| 941 |
+
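# Round-trip sketch, assuming a hypothetical 512x512 image:
#   box = np.array([[10, 20, 110, 220]])        # (y1, x1, y2, x2) in pixels
#   nb = norm_boxes(box, shape=(512, 512))      # values fall in [0, 1]
#   denorm_boxes(nb, shape=(512, 512))          # -> array([[ 10,  20, 110, 220]])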
|
| 942 |
+
|
| 943 |
+
def resize(
|
| 944 |
+
image,
|
| 945 |
+
output_shape,
|
| 946 |
+
order=1,
|
| 947 |
+
mode="constant",
|
| 948 |
+
cval=0,
|
| 949 |
+
clip=True,
|
| 950 |
+
preserve_range=False,
|
| 951 |
+
anti_aliasing=False,
|
| 952 |
+
anti_aliasing_sigma=None,
|
| 953 |
+
):
|
| 954 |
+
"""A wrapper for Scikit-Image resize().
|
| 955 |
+
|
| 956 |
+
Scikit-Image generates warnings on every call to resize() if it doesn't
|
| 957 |
+
receive the right parameters. The right parameters depend on the version
|
| 958 |
+
of skimage. This solves the problem by using different parameters per
|
| 959 |
+
version. And it provides a central place to control resizing defaults.
|
| 960 |
+
"""
|
| 961 |
+
if LooseVersion(skimage.__version__) >= LooseVersion("0.14"):
|
| 962 |
+
# New in 0.14: anti_aliasing. Default it to False for backward
|
| 963 |
+
# compatibility with skimage 0.13.
|
| 964 |
+
return skimage.transform.resize(
|
| 965 |
+
image,
|
| 966 |
+
output_shape,
|
| 967 |
+
order=order,
|
| 968 |
+
mode=mode,
|
| 969 |
+
cval=cval,
|
| 970 |
+
clip=clip,
|
| 971 |
+
preserve_range=preserve_range,
|
| 972 |
+
anti_aliasing=anti_aliasing,
|
| 973 |
+
anti_aliasing_sigma=anti_aliasing_sigma,
|
| 974 |
+
)
|
| 975 |
+
else:
|
| 976 |
+
return skimage.transform.resize(
|
| 977 |
+
image,
|
| 978 |
+
output_shape,
|
| 979 |
+
order=order,
|
| 980 |
+
mode=mode,
|
| 981 |
+
cval=cval,
|
| 982 |
+
clip=clip,
|
| 983 |
+
preserve_range=preserve_range,
|
| 984 |
+
)
|
mrcnn/visualize.py
ADDED
|
@@ -0,0 +1,624 @@
"""
|
| 2 |
+
Mask R-CNN
|
| 3 |
+
Display and Visualization Functions.
|
| 4 |
+
|
| 5 |
+
Copyright (c) 2017 Matterport, Inc.
|
| 6 |
+
Licensed under the MIT License (see LICENSE for details)
|
| 7 |
+
Written by Waleed Abdulla
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import colorsys
|
| 11 |
+
import itertools
|
| 12 |
+
import os
|
| 13 |
+
import random
|
| 14 |
+
import sys
|
| 15 |
+
|
| 16 |
+
import IPython.display
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
import numpy as np
|
| 19 |
+
from matplotlib import lines
|
| 20 |
+
from matplotlib import patches
|
| 21 |
+
from matplotlib.patches import Polygon
|
| 22 |
+
from skimage.measure import find_contours
|
| 23 |
+
|
| 24 |
+
# Root directory of the project
|
| 25 |
+
ROOT_DIR = os.path.abspath("../")
|
| 26 |
+
|
| 27 |
+
# Import Mask RCNN
|
| 28 |
+
sys.path.append(ROOT_DIR) # To find local version of the library
|
| 29 |
+
from mrcnn import utils
|
| 30 |
+
|
| 31 |
+
############################################################
|
| 32 |
+
# Visualization
|
| 33 |
+
############################################################
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def display_images(
|
| 37 |
+
images, titles=None, cols=4, cmap=None, norm=None, interpolation=None
|
| 38 |
+
):
|
| 39 |
+
"""Display the given set of images, optionally with titles.
|
| 40 |
+
images: list or array of image tensors in HWC format.
|
| 41 |
+
titles: optional. A list of titles to display with each image.
|
| 42 |
+
cols: number of images per row
|
| 43 |
+
cmap: Optional. Color map to use. For example, "Blues".
|
| 44 |
+
norm: Optional. A Normalize instance to map values to colors.
|
| 45 |
+
interpolation: Optional. Image interpolation to use for display.
|
| 46 |
+
"""
|
| 47 |
+
titles = titles if titles is not None else [""] * len(images)
|
| 48 |
+
rows = len(images) // cols + 1
|
| 49 |
+
plt.figure(figsize=(14, 14 * rows // cols))
|
| 50 |
+
i = 1
|
| 51 |
+
for image, title in zip(images, titles):
|
| 52 |
+
plt.subplot(rows, cols, i)
|
| 53 |
+
plt.title(title, fontsize=9)
|
| 54 |
+
plt.axis("off")
|
| 55 |
+
plt.imshow(
|
| 56 |
+
image.astype(np.uint8), cmap=cmap, norm=norm, interpolation=interpolation
|
| 57 |
+
)
|
| 58 |
+
i += 1
|
| 59 |
+
plt.show()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def random_colors(N, bright=True):
|
| 63 |
+
"""
|
| 64 |
+
Generate random colors.
|
| 65 |
+
To get visually distinct colors, generate them in HSV space then
|
| 66 |
+
convert to RGB.
|
| 67 |
+
"""
|
| 68 |
+
brightness = 1.0 if bright else 0.7
|
| 69 |
+
hsv = [(i / N, 1, brightness) for i in range(N)]
|
| 70 |
+
colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
|
| 71 |
+
random.shuffle(colors)
|
| 72 |
+
return colors
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def apply_mask(image, mask, color, alpha=0.5):
|
| 76 |
+
"""Apply the given mask to the image."""
|
| 77 |
+
for c in range(3):
|
| 78 |
+
image[:, :, c] = np.where(
|
| 79 |
+
mask == 1,
|
| 80 |
+
image[:, :, c] * (1 - alpha) + alpha * color[c] * 255,
|
| 81 |
+
image[:, :, c],
|
| 82 |
+
)
|
| 83 |
+
return image
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def display_instances(
|
| 87 |
+
image,
|
| 88 |
+
boxes,
|
| 89 |
+
masks,
|
| 90 |
+
class_ids,
|
| 91 |
+
class_names,
|
| 92 |
+
scores=None,
|
| 93 |
+
title="",
|
| 94 |
+
figsize=(16, 16),
|
| 95 |
+
ax=None,
|
| 96 |
+
show_mask=True,
|
| 97 |
+
show_bbox=True,
|
| 98 |
+
colors=None,
|
| 99 |
+
captions=None,
|
| 100 |
+
):
|
| 101 |
+
"""
|
| 102 |
+
boxes: [num_instance, (y1, x1, y2, x2, class_id)] in image coordinates.
|
| 103 |
+
masks: [height, width, num_instances]
|
| 104 |
+
class_ids: [num_instances]
|
| 105 |
+
class_names: list of class names of the dataset
|
| 106 |
+
scores: (optional) confidence scores for each box
|
| 107 |
+
title: (optional) Figure title
|
| 108 |
+
show_mask, show_bbox: To show masks and bounding boxes or not
|
| 109 |
+
figsize: (optional) the size of the image
|
| 110 |
+
colors: (optional) An array of colors to use with each object
|
| 111 |
+
captions: (optional) A list of strings to use as captions for each object
|
| 112 |
+
"""
|
| 113 |
+
# Number of instances
|
| 114 |
+
N = boxes.shape[0]
|
| 115 |
+
if not N:
|
| 116 |
+
print("\n*** No instances to display *** \n")
|
| 117 |
+
else:
|
| 118 |
+
assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]
|
| 119 |
+
|
| 120 |
+
# If no axis is passed, create one and automatically call show()
|
| 121 |
+
auto_show = False
|
| 122 |
+
if not ax:
|
| 123 |
+
_, ax = plt.subplots(1, figsize=figsize)
|
| 124 |
+
auto_show = True
|
| 125 |
+
|
| 126 |
+
# Generate random colors
|
| 127 |
+
colors = colors or random_colors(N)
|
| 128 |
+
|
| 129 |
+
# Show area outside image boundaries.
|
| 130 |
+
height, width = image.shape[:2]
|
| 131 |
+
ax.set_ylim(height + 10, -10)
|
| 132 |
+
ax.set_xlim(-10, width + 10)
|
| 133 |
+
ax.axis("off")
|
| 134 |
+
ax.set_title(title)
|
| 135 |
+
|
| 136 |
+
masked_image = image.astype(np.uint32).copy()
|
| 137 |
+
for i in range(N):
|
| 138 |
+
color = colors[i]
|
| 139 |
+
|
| 140 |
+
# Bounding box
|
| 141 |
+
if not np.any(boxes[i]):
|
| 142 |
+
# Skip this instance. Has no bbox. Likely lost in image cropping.
|
| 143 |
+
continue
|
| 144 |
+
y1, x1, y2, x2 = boxes[i]
|
| 145 |
+
if show_bbox:
|
| 146 |
+
p = patches.Rectangle(
|
| 147 |
+
(x1, y1),
|
| 148 |
+
x2 - x1,
|
| 149 |
+
y2 - y1,
|
| 150 |
+
linewidth=2,
|
| 151 |
+
alpha=0.7,
|
| 152 |
+
linestyle="dashed",
|
| 153 |
+
edgecolor=color,
|
| 154 |
+
facecolor="none",
|
| 155 |
+
)
|
| 156 |
+
ax.add_patch(p)
|
| 157 |
+
|
| 158 |
+
# Label
|
| 159 |
+
if not captions:
|
| 160 |
+
class_id = class_ids[i]
|
| 161 |
+
score = scores[i] if scores is not None else None
|
| 162 |
+
label = class_names[class_id]
|
| 163 |
+
caption = "{} {:.3f}".format(label, score) if score else label
|
| 164 |
+
else:
|
| 165 |
+
caption = captions[i]
|
| 166 |
+
ax.text(x1, y1 + 8, caption, color="w", size=11, backgroundcolor="none")
|
| 167 |
+
|
| 168 |
+
# Mask
|
| 169 |
+
mask = masks[:, :, i]
|
| 170 |
+
if show_mask:
|
| 171 |
+
masked_image = apply_mask(masked_image, mask, color)
|
| 172 |
+
|
| 173 |
+
# Mask Polygon
|
| 174 |
+
# Pad to ensure proper polygons for masks that touch image edges.
|
| 175 |
+
padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
|
| 176 |
+
padded_mask[1:-1, 1:-1] = mask
|
| 177 |
+
contours = find_contours(padded_mask, 0.5)
|
| 178 |
+
for verts in contours:
|
| 179 |
+
# Subtract the padding and flip (y, x) to (x, y)
|
| 180 |
+
verts = np.fliplr(verts) - 1
|
| 181 |
+
p = Polygon(verts, facecolor="none", edgecolor=color)
|
| 182 |
+
ax.add_patch(p)
|
| 183 |
+
|
| 184 |
+
# ax.imshow(masked_image.astype(np.uint8))
|
| 185 |
+
|
| 186 |
+
if auto_show:
|
| 187 |
+
plt.show()
|
| 188 |
+
|
| 189 |
+
return masked_image.astype(np.uint8)
|
| 190 |
+
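# Usage sketch, assuming `results` comes from MaskRCNN.detect() and `dataset`
# exposes class names (as in the Matterport demos):
#   r = results[0]
#   display_instances(image, r["rois"], r["masks"], r["class_ids"],
#                     dataset.class_names, r["scores"], title="Predictions")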
|
| 191 |
+
|
| 192 |
+
def display_differences(
|
| 193 |
+
image,
|
| 194 |
+
gt_box,
|
| 195 |
+
gt_class_id,
|
| 196 |
+
gt_mask,
|
| 197 |
+
pred_box,
|
| 198 |
+
pred_class_id,
|
| 199 |
+
pred_score,
|
| 200 |
+
pred_mask,
|
| 201 |
+
class_names,
|
| 202 |
+
title="",
|
| 203 |
+
ax=None,
|
| 204 |
+
show_mask=True,
|
| 205 |
+
show_box=True,
|
| 206 |
+
iou_threshold=0.5,
|
| 207 |
+
score_threshold=0.5,
|
| 208 |
+
):
|
| 209 |
+
"""Display ground truth and prediction instances on the same image."""
|
| 210 |
+
# Match predictions to ground truth
|
| 211 |
+
gt_match, pred_match, overlaps = utils.compute_matches(
|
| 212 |
+
gt_box,
|
| 213 |
+
gt_class_id,
|
| 214 |
+
gt_mask,
|
| 215 |
+
pred_box,
|
| 216 |
+
pred_class_id,
|
| 217 |
+
pred_score,
|
| 218 |
+
pred_mask,
|
| 219 |
+
iou_threshold=iou_threshold,
|
| 220 |
+
score_threshold=score_threshold,
|
| 221 |
+
)
|
| 222 |
+
# Ground truth = green. Predictions = red
|
| 223 |
+
colors = [(0, 1, 0, 0.8)] * len(gt_match) + [(1, 0, 0, 1)] * len(pred_match)
|
| 224 |
+
# Concatenate GT and predictions
|
| 225 |
+
class_ids = np.concatenate([gt_class_id, pred_class_id])
|
| 226 |
+
scores = np.concatenate([np.zeros([len(gt_match)]), pred_score])
|
| 227 |
+
boxes = np.concatenate([gt_box, pred_box])
|
| 228 |
+
masks = np.concatenate([gt_mask, pred_mask], axis=-1)
|
| 229 |
+
# Captions per instance show score/IoU
|
| 230 |
+
captions = ["" for m in gt_match] + [
|
| 231 |
+
"{:.2f} / {:.2f}".format(
|
| 232 |
+
pred_score[i],
|
| 233 |
+
(
|
| 234 |
+
overlaps[i, int(pred_match[i])]
|
| 235 |
+
if pred_match[i] > -1
|
| 236 |
+
else overlaps[i].max()
|
| 237 |
+
),
|
| 238 |
+
)
|
| 239 |
+
for i in range(len(pred_match))
|
| 240 |
+
]
|
| 241 |
+
# Set title if not provided
|
| 242 |
+
title = (
|
| 243 |
+
title or "Ground Truth and Detections\n GT=green, pred=red, captions: score/IoU"
|
| 244 |
+
)
|
| 245 |
+
# Display
|
| 246 |
+
display_instances(
|
| 247 |
+
image,
|
| 248 |
+
boxes,
|
| 249 |
+
masks,
|
| 250 |
+
class_ids,
|
| 251 |
+
class_names,
|
| 252 |
+
scores,
|
| 253 |
+
ax=ax,
|
| 254 |
+
show_bbox=show_box,
|
| 255 |
+
show_mask=show_mask,
|
| 256 |
+
colors=colors,
|
| 257 |
+
captions=captions,
|
| 258 |
+
title=title,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def draw_rois(image, rois, refined_rois, mask, class_ids, class_names, limit=10):
|
| 263 |
+
"""
|
| 264 |
+
rois: [n, (y1, x1, y2, x2)] list of ROIs (anchors) in image coordinates.
|
| 265 |
+
refined_rois: [n, 4] the same ROIs refined to fit objects better.
|
| 266 |
+
"""
|
| 267 |
+
masked_image = image.copy()
|
| 268 |
+
|
| 269 |
+
# Pick random anchors in case there are too many.
|
| 270 |
+
ids = np.arange(rois.shape[0], dtype=np.int32)
|
| 271 |
+
ids = np.random.choice(ids, limit, replace=False) if ids.shape[0] > limit else ids
|
| 272 |
+
|
| 273 |
+
fig, ax = plt.subplots(1, figsize=(12, 12))
|
| 274 |
+
if rois.shape[0] > limit:
|
| 275 |
+
plt.title("Showing {} random ROIs out of {}".format(len(ids), rois.shape[0]))
|
| 276 |
+
else:
|
| 277 |
+
plt.title("{} ROIs".format(len(ids)))
|
| 278 |
+
|
| 279 |
+
# Show area outside image boundaries.
|
| 280 |
+
ax.set_ylim(image.shape[0] + 20, -20)
|
| 281 |
+
ax.set_xlim(-50, image.shape[1] + 20)
|
| 282 |
+
ax.axis("off")
|
| 283 |
+
|
| 284 |
+
for i, id in enumerate(ids):
|
| 285 |
+
color = np.random.rand(3)
|
| 286 |
+
class_id = class_ids[id]
|
| 287 |
+
# ROI
|
| 288 |
+
y1, x1, y2, x2 = rois[id]
|
| 289 |
+
p = patches.Rectangle(
|
| 290 |
+
(x1, y1),
|
| 291 |
+
x2 - x1,
|
| 292 |
+
y2 - y1,
|
| 293 |
+
linewidth=2,
|
| 294 |
+
edgecolor=color if class_id else "gray",
|
| 295 |
+
facecolor="none",
|
| 296 |
+
linestyle="dashed",
|
| 297 |
+
)
|
| 298 |
+
ax.add_patch(p)
|
| 299 |
+
# Refined ROI
|
| 300 |
+
if class_id:
|
| 301 |
+
ry1, rx1, ry2, rx2 = refined_rois[id]
|
| 302 |
+
p = patches.Rectangle(
|
| 303 |
+
(rx1, ry1),
|
| 304 |
+
rx2 - rx1,
|
| 305 |
+
ry2 - ry1,
|
| 306 |
+
linewidth=2,
|
| 307 |
+
edgecolor=color,
|
| 308 |
+
facecolor="none",
|
| 309 |
+
)
|
| 310 |
+
ax.add_patch(p)
|
| 311 |
+
# Connect the top-left corners of the anchor and proposal for easy visualization
|
| 312 |
+
ax.add_line(lines.Line2D([x1, rx1], [y1, ry1], color=color))
|
| 313 |
+
|
| 314 |
+
# Label
|
| 315 |
+
label = class_names[class_id]
|
| 316 |
+
ax.text(
|
| 317 |
+
rx1,
|
| 318 |
+
ry1 + 8,
|
| 319 |
+
"{}".format(label),
|
| 320 |
+
color="w",
|
| 321 |
+
size=11,
|
| 322 |
+
backgroundcolor="none",
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
# Mask
|
| 326 |
+
m = utils.unmold_mask(mask[id], rois[id][:4].astype(np.int32), image.shape)
|
| 327 |
+
masked_image = apply_mask(masked_image, m, color)
|
| 328 |
+
|
| 329 |
+
ax.imshow(masked_image)
|
| 330 |
+
|
| 331 |
+
# Print stats
|
| 332 |
+
print("Positive ROIs: ", class_ids[class_ids > 0].shape[0])
|
| 333 |
+
print("Negative ROIs: ", class_ids[class_ids == 0].shape[0])
|
| 334 |
+
print(
|
| 335 |
+
"Positive Ratio: {:.2f}".format(
|
| 336 |
+
class_ids[class_ids > 0].shape[0] / class_ids.shape[0]
|
| 337 |
+
)
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# TODO: Replace with matplotlib equivalent?
|
| 342 |
+
def draw_box(image, box, color):
|
| 343 |
+
"""Draw 3-pixel width bounding boxes on the given image array.
|
| 344 |
+
color: list of 3 int values for RGB.
|
| 345 |
+
"""
|
| 346 |
+
y1, x1, y2, x2 = box
|
| 347 |
+
image[y1 : y1 + 2, x1:x2] = color
|
| 348 |
+
image[y2 : y2 + 2, x1:x2] = color
|
| 349 |
+
image[y1:y2, x1 : x1 + 2] = color
|
| 350 |
+
image[y1:y2, x2 : x2 + 2] = color
|
| 351 |
+
return image
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def display_top_masks(image, mask, class_ids, class_names, limit=4):
|
| 355 |
+
"""Display the given image and the top few class masks."""
|
| 356 |
+
to_display = []
|
| 357 |
+
titles = []
|
| 358 |
+
to_display.append(image)
|
| 359 |
+
titles.append("H x W={}x{}".format(image.shape[0], image.shape[1]))
|
| 360 |
+
# Pick top prominent classes in this image
|
| 361 |
+
unique_class_ids = np.unique(class_ids)
|
| 362 |
+
mask_area = [
|
| 363 |
+
np.sum(mask[:, :, np.where(class_ids == i)[0]]) for i in unique_class_ids
|
| 364 |
+
]
|
| 365 |
+
top_ids = [
|
| 366 |
+
v[0]
|
| 367 |
+
for v in sorted(
|
| 368 |
+
zip(unique_class_ids, mask_area), key=lambda r: r[1], reverse=True
|
| 369 |
+
)
|
| 370 |
+
if v[1] > 0
|
| 371 |
+
]
|
| 372 |
+
# Generate images and titles
|
| 373 |
+
for i in range(limit):
|
| 374 |
+
class_id = top_ids[i] if i < len(top_ids) else -1
|
| 375 |
+
# Pull masks of instances belonging to the same class.
|
| 376 |
+
m = mask[:, :, np.where(class_ids == class_id)[0]]
|
| 377 |
+
m = np.sum(m * np.arange(1, m.shape[-1] + 1), -1)
|
| 378 |
+
to_display.append(m)
|
| 379 |
+
titles.append(class_names[class_id] if class_id != -1 else "-")
|
| 380 |
+
display_images(to_display, titles=titles, cols=limit + 1, cmap="Blues_r")
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def plot_precision_recall(AP, precisions, recalls):
|
| 384 |
+
"""Draw the precision-recall curve.
|
| 385 |
+
|
| 386 |
+
AP: Average precision at IoU >= 0.5
|
| 387 |
+
precisions: list of precision values
|
| 388 |
+
recalls: list of recall values
|
| 389 |
+
"""
|
| 390 |
+
# Plot the Precision-Recall curve
|
| 391 |
+
_, ax = plt.subplots(1)
|
| 392 |
+
ax.set_title("Precision-Recall Curve. AP@50 = {:.3f}".format(AP))
|
| 393 |
+
ax.set_ylim(0, 1.1)
|
| 394 |
+
ax.set_xlim(0, 1.1)
|
| 395 |
+
_ = ax.plot(recalls, precisions)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def plot_overlaps(
|
| 399 |
+
gt_class_ids, pred_class_ids, pred_scores, overlaps, class_names, threshold=0.5
|
| 400 |
+
):
|
| 401 |
+
"""Draw a grid showing how ground truth objects are classified.
|
| 402 |
+
gt_class_ids: [N] int. Ground truth class IDs
|
| 403 |
+
pred_class_id: [N] int. Predicted class IDs
|
| 404 |
+
pred_scores: [N] float. The probability scores of predicted classes
|
| 405 |
+
overlaps: [pred_boxes, gt_boxes] IoU overlaps of predictions and GT boxes.
|
| 406 |
+
class_names: list of all class names in the dataset
|
| 407 |
+
threshold: Float. The prediction probability required to predict a class
|
| 408 |
+
"""
|
| 409 |
+
gt_class_ids = gt_class_ids[gt_class_ids != 0]
|
| 410 |
+
pred_class_ids = pred_class_ids[pred_class_ids != 0]
|
| 411 |
+
|
| 412 |
+
plt.figure(figsize=(12, 10))
|
| 413 |
+
plt.imshow(overlaps, interpolation="nearest", cmap=plt.cm.Blues)
|
| 414 |
+
plt.yticks(
|
| 415 |
+
np.arange(len(pred_class_ids)),
|
| 416 |
+
[
|
| 417 |
+
"{} ({:.2f})".format(class_names[int(id)], pred_scores[i])
|
| 418 |
+
for i, id in enumerate(pred_class_ids)
|
| 419 |
+
],
|
| 420 |
+
)
|
| 421 |
+
plt.xticks(
|
| 422 |
+
np.arange(len(gt_class_ids)),
|
| 423 |
+
[class_names[int(id)] for id in gt_class_ids],
|
| 424 |
+
rotation=90,
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
thresh = overlaps.max() / 2.0
|
| 428 |
+
for i, j in itertools.product(range(overlaps.shape[0]), range(overlaps.shape[1])):
|
| 429 |
+
text = ""
|
| 430 |
+
if overlaps[i, j] > threshold:
|
| 431 |
+
text = "match" if gt_class_ids[j] == pred_class_ids[i] else "wrong"
|
| 432 |
+
color = (
|
| 433 |
+
"white"
|
| 434 |
+
if overlaps[i, j] > thresh
|
| 435 |
+
else "black"
|
| 436 |
+
if overlaps[i, j] > 0
|
| 437 |
+
else "grey"
|
| 438 |
+
)
|
| 439 |
+
plt.text(
|
| 440 |
+
j,
|
| 441 |
+
i,
|
| 442 |
+
"{:.3f}\n{}".format(overlaps[i, j], text),
|
| 443 |
+
horizontalalignment="center",
|
| 444 |
+
verticalalignment="center",
|
| 445 |
+
fontsize=9,
|
| 446 |
+
color=color,
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
plt.tight_layout()
|
| 450 |
+
plt.xlabel("Ground Truth")
|
| 451 |
+
plt.ylabel("Predictions")
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def draw_boxes(
|
| 455 |
+
image,
|
| 456 |
+
boxes=None,
|
| 457 |
+
refined_boxes=None,
|
| 458 |
+
masks=None,
|
| 459 |
+
captions=None,
|
| 460 |
+
visibilities=None,
|
| 461 |
+
title="",
|
| 462 |
+
ax=None,
|
| 463 |
+
):
|
| 464 |
+
"""Draw bounding boxes and segmentation masks with different
|
| 465 |
+
customizations.
|
| 466 |
+
|
| 467 |
+
boxes: [N, (y1, x1, y2, x2, class_id)] in image coordinates.
|
| 468 |
+
refined_boxes: Like boxes, but draw with solid lines to show
|
| 469 |
+
that they're the result of refining 'boxes'.
|
| 470 |
+
masks: [N, height, width]
|
| 471 |
+
captions: List of N titles to display on each box
|
| 472 |
+
visibilities: (optional) List of values of 0, 1, or 2. Determine how
|
| 473 |
+
prominent each bounding box should be.
|
| 474 |
+
title: An optional title to show over the image
|
| 475 |
+
ax: (optional) Matplotlib axis to draw on.
|
| 476 |
+
"""
|
| 477 |
+
# Number of boxes
|
| 478 |
+
assert boxes is not None or refined_boxes is not None
|
| 479 |
+
N = boxes.shape[0] if boxes is not None else refined_boxes.shape[0]
|
| 480 |
+
|
| 481 |
+
# Matplotlib Axis
|
| 482 |
+
if not ax:
|
| 483 |
+
_, ax = plt.subplots(1, figsize=(12, 12))
|
| 484 |
+
|
| 485 |
+
# Generate random colors
|
| 486 |
+
colors = random_colors(N)
|
| 487 |
+
|
| 488 |
+
# Show area outside image boundaries.
|
| 489 |
+
margin = image.shape[0] // 10
|
| 490 |
+
ax.set_ylim(image.shape[0] + margin, -margin)
|
| 491 |
+
ax.set_xlim(-margin, image.shape[1] + margin)
|
| 492 |
+
ax.axis("off")
|
| 493 |
+
|
| 494 |
+
ax.set_title(title)
|
| 495 |
+
|
| 496 |
+
masked_image = image.astype(np.uint32).copy()
|
| 497 |
+
for i in range(N):
|
| 498 |
+
# Box visibility
|
| 499 |
+
visibility = visibilities[i] if visibilities is not None else 1
|
| 500 |
+
if visibility == 0:
|
| 501 |
+
color = "gray"
|
| 502 |
+
style = "dotted"
|
| 503 |
+
alpha = 0.5
|
| 504 |
+
elif visibility == 1:
|
| 505 |
+
color = colors[i]
|
| 506 |
+
style = "dotted"
|
| 507 |
+
alpha = 1
|
| 508 |
+
elif visibility == 2:
|
| 509 |
+
color = colors[i]
|
| 510 |
+
style = "solid"
|
| 511 |
+
alpha = 1
|
| 512 |
+
|
| 513 |
+
# Boxes
|
| 514 |
+
if boxes is not None:
|
| 515 |
+
if not np.any(boxes[i]):
|
| 516 |
+
# Skip this instance. Has no bbox. Likely lost in cropping.
|
| 517 |
+
continue
|
| 518 |
+
y1, x1, y2, x2 = boxes[i]
|
| 519 |
+
p = patches.Rectangle(
|
| 520 |
+
(x1, y1),
|
| 521 |
+
x2 - x1,
|
| 522 |
+
y2 - y1,
|
| 523 |
+
linewidth=2,
|
| 524 |
+
alpha=alpha,
|
| 525 |
+
linestyle=style,
|
| 526 |
+
edgecolor=color,
|
| 527 |
+
facecolor="none",
|
| 528 |
+
)
|
| 529 |
+
ax.add_patch(p)
|
| 530 |
+
|
| 531 |
+
# Refined boxes
|
| 532 |
+
if refined_boxes is not None and visibility > 0:
|
| 533 |
+
ry1, rx1, ry2, rx2 = refined_boxes[i].astype(np.int32)
|
| 534 |
+
p = patches.Rectangle(
|
| 535 |
+
(rx1, ry1),
|
| 536 |
+
rx2 - rx1,
|
| 537 |
+
ry2 - ry1,
|
| 538 |
+
linewidth=2,
|
| 539 |
+
edgecolor=color,
|
| 540 |
+
facecolor="none",
|
| 541 |
+
)
|
| 542 |
+
ax.add_patch(p)
|
| 543 |
+
# Connect the top-left corners of the anchor and proposal
|
| 544 |
+
if boxes is not None:
|
| 545 |
+
ax.add_line(lines.Line2D([x1, rx1], [y1, ry1], color=color))
|
| 546 |
+
|
| 547 |
+
# Captions
|
| 548 |
+
if captions is not None:
|
| 549 |
+
caption = captions[i]
|
| 550 |
+
# If there are refined boxes, display captions on them
|
| 551 |
+
if refined_boxes is not None:
|
| 552 |
+
y1, x1, y2, x2 = ry1, rx1, ry2, rx2
|
| 553 |
+
ax.text(
|
| 554 |
+
x1,
|
| 555 |
+
y1,
|
| 556 |
+
caption,
|
| 557 |
+
size=11,
|
| 558 |
+
verticalalignment="top",
|
| 559 |
+
color="w",
|
| 560 |
+
backgroundcolor="none",
|
| 561 |
+
bbox={"facecolor": color, "alpha": 0.5, "pad": 2, "edgecolor": "none"},
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
# Masks
|
| 565 |
+
if masks is not None:
|
| 566 |
+
mask = masks[:, :, i]
|
| 567 |
+
masked_image = apply_mask(masked_image, mask, color)
|
| 568 |
+
# Mask Polygon
|
| 569 |
+
# Pad to ensure proper polygons for masks that touch image edges.
|
| 570 |
+
padded_mask = np.zeros(
|
| 571 |
+
(mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8
|
| 572 |
+
)
|
| 573 |
+
padded_mask[1:-1, 1:-1] = mask
|
| 574 |
+
contours = find_contours(padded_mask, 0.5)
|
| 575 |
+
for verts in contours:
|
| 576 |
+
# Subtract the padding and flip (y, x) to (x, y)
|
| 577 |
+
verts = np.fliplr(verts) - 1
|
| 578 |
+
p = Polygon(verts, facecolor="none", edgecolor=color)
|
| 579 |
+
ax.add_patch(p)
|
| 580 |
+
ax.imshow(masked_image.astype(np.uint8))
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def display_table(table):
|
| 584 |
+
"""Display values in a table format.
|
| 585 |
+
table: an iterable of rows, and each row is an iterable of values.
|
| 586 |
+
"""
|
| 587 |
+
html = ""
|
| 588 |
+
for row in table:
|
| 589 |
+
row_html = ""
|
| 590 |
+
for col in row:
|
| 591 |
+
row_html += "<td>{:40}</td>".format(str(col))
|
| 592 |
+
html += "<tr>" + row_html + "</tr>"
|
| 593 |
+
html = "<table>" + html + "</table>"
|
| 594 |
+
IPython.display.display(IPython.display.HTML(html))
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
def display_weight_stats(model):
|
| 598 |
+
"""Scans all the weights in the model and returns a list of tuples
|
| 599 |
+
that contain stats about each weight.
|
| 600 |
+
"""
|
| 601 |
+
layers = model.get_trainable_layers()
|
| 602 |
+
table = [["WEIGHT NAME", "SHAPE", "MIN", "MAX", "STD"]]
|
| 603 |
+
for l in layers:
|
| 604 |
+
weight_values = l.get_weights() # list of Numpy arrays
|
| 605 |
+
weight_tensors = l.weights # list of TF tensors
|
| 606 |
+
for i, w in enumerate(weight_values):
|
| 607 |
+
weight_name = weight_tensors[i].name
|
| 608 |
+
# Detect problematic layers. Exclude biases of conv layers.
|
| 609 |
+
alert = ""
|
| 610 |
+
if w.min() == w.max() and not (l.__class__.__name__ == "Conv2D" and i == 1):
|
| 611 |
+
alert += "<span style='color:red'>*** dead?</span>"
|
| 612 |
+
if np.abs(w.min()) > 1000 or np.abs(w.max()) > 1000:
|
| 613 |
+
alert += "<span style='color:red'>*** Overflow?</span>"
|
| 614 |
+
# Add row
|
| 615 |
+
table.append(
|
| 616 |
+
[
|
| 617 |
+
weight_name + alert,
|
| 618 |
+
str(w.shape),
|
| 619 |
+
"{:+9.4f}".format(w.min()),
|
| 620 |
+
"{:+10.4f}".format(w.max()),
|
| 621 |
+
"{:+9.4f}".format(w.std()),
|
| 622 |
+
]
|
| 623 |
+
)
|
| 624 |
+
display_table(table)
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
tensorflow==1.14.0
|
| 2 |
+
keras==2.0.8
|
| 3 |
+
protobuf==3.20.1
|
| 4 |
+
gradio==3.0.15
|
| 5 |
+
gdown==4.4.0
|
| 6 |
+
numpy
|
| 7 |
+
scipy
|
| 8 |
+
Pillow
|
| 9 |
+
cython
|
| 10 |
+
matplotlib
|
| 11 |
+
scikit-image
|
| 12 |
+
opencv-python
|
| 13 |
+
h5py==2.10.0
|
| 14 |
+
imgaug
|
| 15 |
+
IPython[all]
|
setup.cfg
ADDED
|
@@ -0,0 +1,4 @@
[metadata]
|
| 2 |
+
description-file = README.md
|
| 3 |
+
license-file = LICENSE
|
| 4 |
+
requirements-file = requirements.txt
|
setup.py
ADDED
|
@@ -0,0 +1,69 @@
"""
|
| 2 |
+
The build/compilation setup
|
| 3 |
+
|
| 4 |
+
>> pip install -r requirements.txt
|
| 5 |
+
>> python setup.py install
|
| 6 |
+
"""
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
import pip
|
| 10 |
+
import pkg_resources
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from setuptools import setup
|
| 14 |
+
except ImportError:
|
| 15 |
+
from distutils.core import setup
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _parse_requirements(file_path):
|
| 19 |
+
pip_ver = pkg_resources.get_distribution("pip").version
|
| 20 |
+
pip_version = list(map(int, pip_ver.split(".")[:2]))
|
| 21 |
+
if pip_version >= [6, 0]:
|
| 22 |
+
raw = pip.req.parse_requirements(file_path, session=pip.download.PipSession())
|
| 23 |
+
else:
|
| 24 |
+
raw = pip.req.parse_requirements(file_path)
|
| 25 |
+
return [str(i.req) for i in raw]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# parse_requirements() returns generator of pip.req.InstallRequirement objects
|
| 29 |
+
try:
|
| 30 |
+
install_reqs = _parse_requirements("requirements.txt")
|
| 31 |
+
except Exception:
|
| 32 |
+
logging.warning("Fail load requirements file, so using default ones.")
|
| 33 |
+
install_reqs = []
|
| 34 |
+
|
| 35 |
+
setup(
|
| 36 |
+
name="mask-rcnn",
|
| 37 |
+
version="2.1",
|
| 38 |
+
url="https://github.com/matterport/Mask_RCNN",
|
| 39 |
+
author="Matterport",
|
| 40 |
+
author_email="[email protected]",
|
| 41 |
+
license="MIT",
|
| 42 |
+
description="Mask R-CNN for object detection and instance segmentation",
|
| 43 |
+
packages=["mrcnn"],
|
| 44 |
+
install_requires=install_reqs,
|
| 45 |
+
include_package_data=True,
|
| 46 |
+
python_requires=">=3.4",
|
| 47 |
+
long_description="""This is an implementation of Mask R-CNN on Python 3, Keras, and TensorFlow.
|
| 48 |
+
The model generates bounding boxes and segmentation masks for each instance of an object in the image.
|
| 49 |
+
It's based on Feature Pyramid Network (FPN) and a ResNet101 backbone.""",
|
| 50 |
+
classifiers=[
|
| 51 |
+
"Development Status :: 5 - Production/Stable",
|
| 52 |
+
"Environment :: Console",
|
| 53 |
+
"Intended Audience :: Developers",
|
| 54 |
+
"Intended Audience :: Information Technology",
|
| 55 |
+
"Intended Audience :: Education",
|
| 56 |
+
"Intended Audience :: Science/Research",
|
| 57 |
+
"License :: OSI Approved :: MIT License",
|
| 58 |
+
"Natural Language :: English",
|
| 59 |
+
"Operating System :: OS Independent",
|
| 60 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 61 |
+
"Topic :: Scientific/Engineering :: Image Recognition",
|
| 62 |
+
"Topic :: Scientific/Engineering :: Visualization",
|
| 63 |
+
"Topic :: Scientific/Engineering :: Image Segmentation",
|
| 64 |
+
"Programming Language :: Python :: 3.4",
|
| 65 |
+
"Programming Language :: Python :: 3.5",
|
| 66 |
+
"Programming Language :: Python :: 3.6",
|
| 67 |
+
],
|
| 68 |
+
keywords="image instance segmentation object detection mask rcnn r-cnn tensorflow keras",
|
| 69 |
+
)
|
utils.py
ADDED
|
@@ -0,0 +1,13 @@
import matplotlib.pyplot as plt
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# from official repo
|
| 5 |
+
def get_ax(rows=1, cols=1, size=7):
|
| 6 |
+
"""Return a Matplotlib Axes array to be used in
|
| 7 |
+
all visualizations in the notebook. Provide a
|
| 8 |
+
central point to control graph sizes.
|
| 9 |
+
|
| 10 |
+
Adjust the size attribute to control how big to render images
|
| 11 |
+
"""
|
| 12 |
+
_, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
|
| 13 |
+
return ax
|
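# Usage sketch: a single 7-inch axis, or a 1x2 grid for side-by-side plots.
#   ax = get_ax()
#   ax = get_ax(rows=1, cols=2, size=7)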