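# Gradio demo: detects electrical receptacle symbols on construction drawings for
# takeoff automation, using a custom YOLOv5 model ("best.pt") with SAHI sliced inference.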
import torch
import gradio as gr
import cv2
import numpy as np
from sahi.utils.yolov5 import download_yolov5s6_model
# import required functions, classes
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction, visualize_object_predictions
# Autodetect GPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Load the model weights; the helper only downloads the default yolov5s6 weights
# if no file already exists at this path, so a local best.pt is left untouched
yolov5_model_path = "best.pt"
download_yolov5s6_model(destination_path=yolov5_model_path)
detection_model = AutoDetectionModel.from_pretrained(
    model_type="yolov5",
    model_path=yolov5_model_path,
    confidence_threshold=0.01,
    device=device,
)
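# The global confidence_threshold is kept low here; the effective per-class filtering
# happens later in do_detection(), driven by the UI sliders.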
def do_detection(image, hide_labels, confidence_scores):
    # `image` is the uploaded drawing as a numpy array (Gradio's default Image type)
    # Obtain detection results with SAHI sliced inference
    result = get_sliced_prediction(
        image,
        detection_model,
        slice_height=512,
        slice_width=512,
        overlap_height_ratio=0.12,
        overlap_width_ratio=0.12,
    )
    # Filter detections according to the per-class sliders and count how many
    # instances of each class survive, for the legend
    predictions = []
    class_counts = {}
    for prediction in result.object_prediction_list:
        category_name = prediction.category.name
        if prediction.score.value > confidence_scores[category_name]:
            predictions.append(prediction)
            class_counts[category_name] = class_counts.get(category_name, 0) + 1
    # Draw the boxes and labels on top of the image
    img_rgb = visualize_object_predictions(
        image,
        object_prediction_list=predictions,
        text_size=1,
        text_th=1,
        hide_labels=hide_labels,
        rect_th=3,
    )["image"]
    # Construct a legend listing the per-class counts
    legend_text = "Symbols Counted:"
    for class_name, count in class_counts.items():
        legend_text += f" {class_name}: {count} |"
    font = cv2.FONT_HERSHEY_SIMPLEX
    if hide_labels:
        font_scale = 1.5
    else:
        font_scale = 1
    font_color = (255, 255, 255)
    font_thickness = 2
    legend_bg_color = (131, 79, 0)
    legend_padding = 10
    legend_size, _ = cv2.getTextSize(legend_text, font, font_scale, font_thickness)
    legend_bg_height = legend_size[1] + 2 * legend_padding
    legend_bg_width = legend_size[0] + 2 * legend_padding
    legend_bg = np.zeros((legend_bg_height, legend_bg_width, 3), dtype=np.uint8)
    legend_bg[:] = legend_bg_color
    cv2.putText(
        legend_bg,
        legend_text,
        (legend_padding, legend_padding + legend_size[1]),
        font,
        font_scale,
        font_color,
        font_thickness,
    )
    # Overlay the legend in the bottom-right corner of the annotated image
    img_height, img_width, _ = img_rgb.shape
    legend_x = img_width - legend_bg_width
    legend_y = img_height - legend_bg_height
    img_rgb[legend_y:, legend_x:, :] = legend_bg
    return (
        img_rgb,
        result.to_coco_predictions(),
    )
def call_func(
    image,
    hide_labels,
    singleplex_value,
    duplex_value,
    triplex_value,
    quadruplex_value,
    gfci_value,
    gfci_wp_value,
):
    # Gather the individual slider values into the per-class threshold lookup
    # expected by do_detection()
    confidence_scores = {
        "Singleplex - Standard": singleplex_value,
        "Duplex - Standard": duplex_value,
        "Triplex - Standard": triplex_value,
        "Quadruplex - Standard": quadruplex_value,
        "Duplex - GFCI": gfci_value,
        "Duplex - Weatherproof-GFCI": gfci_wp_value,
    }
    return do_detection(image, hide_labels, confidence_scores)
theme = gr.themes.Soft()
with gr.Blocks(theme=theme) as demo:
    gr.Markdown(
        """
        <h1 align="center">Receptacle Detector for Takeoff Automation</h1>
        """
    )
    with gr.Row():
        input_image = gr.Image(
            label="Upload an image here.",
            source="upload",
            interactive=True,
        )
        examples = gr.Examples(
            examples=[
                ["test1.jpg"],
                ["test2.jpg"],
                ["test3.jpg"],
                ["test4.jpg"],
            ],
            inputs=[input_image],
            examples_per_page=4,
            label="Examples to use.",
        )
    hide_labels = gr.Checkbox(label="Hide labels")
    with gr.Accordion("Visualization Confidence Thresholds", open=False):
        singleplex_slider = gr.Slider(
            minimum=0.1,
            maximum=1,
            value=0.53,
            interactive=True,
            label="Singleplex",
        )
        duplex_slider = gr.Slider(
            minimum=0.1,
            maximum=1,
            value=0.66,
            interactive=True,
            label="Duplex",
        )
        triplex_slider = gr.Slider(
            minimum=0.1,
            maximum=1,
            value=0.65,
            interactive=True,
            label="Triplex",
        )
        quadruplex_slider = gr.Slider(
            minimum=0.1,
            maximum=1,
            value=0.63,
            interactive=True,
            label="Quadruplex",
        )
        gfci_slider = gr.Slider(
            minimum=0.1,
            maximum=1,
            value=0.31,
            interactive=True,
            label="GFCI",
        )
        gfci_wp_slider = gr.Slider(
            minimum=0.1,
            maximum=1,
            value=0.33,
            interactive=True,
            label="GFCI/WP",
        )
    results_button = gr.Button("Submit")
    results_button.click(
        call_func,
        inputs=[
            input_image,
            hide_labels,
            singleplex_slider,
            duplex_slider,
            triplex_slider,
            quadruplex_slider,
            gfci_slider,
            gfci_wp_slider,
        ],
        outputs=[
            gr.Image(type="numpy", label="Output Image"),
            gr.Json(),
        ],
    )
demo.launch()
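# Note: gr.Image(source="upload") and gr.Json follow the Gradio 3.x API; newer Gradio
# releases changed these component arguments, so pin the gradio version accordingly.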