# Page header captured along with the source ("Spaces: Sleeping"); not part of the program.
from collections import Counter

import cv2
import gradio as gr
import numpy as np
import torch

# import required functions, classes
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction, visualize_object_predictions
from sahi.utils.yolov5 import download_yolov5s6_model
# Run inference on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the YOLOv5 weights through SAHI's auto-detection wrapper.
# NOTE(review): download_yolov5s6_model is SAHI's generic YOLOv5s6 downloader pointed
# at "best.pt"; presumably it does nothing when the file already exists — confirm,
# because generic weights would not emit the class names the UI thresholds expect.
yolov5_model_path = "best.pt"
download_yolov5s6_model(destination_path=yolov5_model_path)
detection_model = AutoDetectionModel.from_pretrained(
    model_type="yolov5",
    model_path=yolov5_model_path,
    # Kept low here; do_detection applies the per-class thresholds afterwards.
    confidence_threshold=0.01,
    device=device,
)
def _draw_legend(image, legend_text, font_scale):
    """Paint *legend_text* on a solid banner in the bottom-right corner of *image*.

    The image is modified in place. If the banner would be wider than the image,
    the font is scaled down to fit, and the banner is finally clipped to the image
    bounds so the slice assignment can never fail on a shape mismatch.

    Args:
        image: H x W x 3 uint8 array to draw on.
        legend_text: Text to render inside the banner.
        font_scale: Preferred OpenCV font scale; reduced only when needed to fit.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_color = (255, 255, 255)
    font_thickness = 2
    legend_bg_color = (131, 79, 0)
    legend_padding = 10

    img_height, img_width = image.shape[:2]

    (text_width, text_height), _ = cv2.getTextSize(
        legend_text, font, font_scale, font_thickness
    )

    # A long legend on a small image used to overflow the image and raise a
    # broadcast ValueError; shrink the text so the banner fits the image width.
    available_width = img_width - 2 * legend_padding
    if 0 < available_width < text_width:
        font_scale *= available_width / text_width
        (text_width, text_height), _ = cv2.getTextSize(
            legend_text, font, font_scale, font_thickness
        )

    legend_bg_height = text_height + 2 * legend_padding
    legend_bg_width = text_width + 2 * legend_padding
    legend_bg = np.full(
        (legend_bg_height, legend_bg_width, 3), legend_bg_color, dtype=np.uint8
    )
    cv2.putText(
        legend_bg,
        legend_text,
        (legend_padding, legend_padding + text_height),
        font,
        font_scale,
        font_color,
        font_thickness,
    )

    # Safety net: never write more rows/columns than the image actually has.
    overlay_height = min(legend_bg_height, img_height)
    overlay_width = min(legend_bg_width, img_width)
    image[img_height - overlay_height:, img_width - overlay_width:, :] = legend_bg[
        :overlay_height, :overlay_width
    ]


def do_detection(image_path, hide_labels, confidence_scores):
    """Run sliced detection on an image and return an annotated copy plus raw results.

    Args:
        image_path: Image handed unchanged to ``get_sliced_prediction`` and
            ``visualize_object_predictions``. Presumably the array produced by the
            Gradio image input despite the name — TODO confirm.
        hide_labels: When true, box labels are hidden and the legend uses a
            larger font.
        confidence_scores: Mapping of class name to the minimum score a detection
            of that class must exceed to be drawn and counted.

    Returns:
        A tuple ``(annotated_image, coco_predictions)``. The COCO predictions are
        built from the unfiltered result; the per-class thresholds only affect
        the drawn image and the legend counts.

    Raises:
        KeyError: If a detection's class name has no entry in ``confidence_scores``.
    """
    # Obtain detection results
    result = get_sliced_prediction(
        image_path,
        detection_model,
        slice_height=512,
        slice_width=512,
        overlap_height_ratio=0.12,
        overlap_width_ratio=0.12,
    )

    # Keep only detections that beat their class-specific threshold.
    predictions = [
        prediction
        for prediction in result.object_prediction_list
        if prediction.score.value > confidence_scores[prediction.category.name]
    ]
    # NOTE(review): counts cover the kept detections only; the original nesting
    # was not recoverable from the source formatting — confirm.
    class_counts = Counter(prediction.category.name for prediction in predictions)

    # Draw the boxes and labels on top of the image
    img_rgb = visualize_object_predictions(
        image_path,
        object_prediction_list=predictions,
        text_size=1,
        text_th=1,
        hide_labels=hide_labels,
        rect_th=3,
    )["image"]

    # Construct a legend listing each class (first-seen order) with its count.
    legend_text = "Symbols Counted:" + "".join(
        f" {class_name}: {count} |" for class_name, count in class_counts.items()
    )
    font_scale = 1.5 if hide_labels else 1
    _draw_legend(img_rgb, legend_text, font_scale)

    return (
        img_rgb,
        result.to_coco_predictions(),
    )
def call_func(
    image_path,
    hide_labels,
    singleplex_value,
    duplex_value,
    triplex_value,
    quadruplex_value,
    gfci_value,
    gfci_wp_value,
):
    """Gradio click handler: map the slider values to class thresholds and detect.

    Args:
        image_path: Value of the image input; ``None`` when nothing was uploaded.
        hide_labels: Value of the "Hide labels" checkbox.
        singleplex_value: Threshold for "Singleplex - Standard".
        duplex_value: Threshold for "Duplex - Standard".
        triplex_value: Threshold for "Triplex - Standard".
        quadruplex_value: Threshold for "Quadruplex - Standard".
        gfci_value: Threshold for "Duplex - GFCI".
        gfci_wp_value: Threshold for "Duplex - Weatherproof-GFCI".

    Returns:
        Whatever ``do_detection`` returns for the image and thresholds.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of an opaque failure from inside the detection call.
    """
    if image_path is None:
        raise gr.Error("Please upload an image before submitting.")

    confidence_scores = {
        "Singleplex - Standard": singleplex_value,
        "Duplex - Standard": duplex_value,
        "Triplex - Standard": triplex_value,
        "Quadruplex - Standard": quadruplex_value,
        "Duplex - GFCI": gfci_value,
        "Duplex - Weatherproof-GFCI": gfci_wp_value,
    }
    return do_detection(image_path, hide_labels, confidence_scores)
def _threshold_slider(label, value):
    """Create one per-class confidence slider; must be called inside a Blocks context."""
    return gr.Slider(
        minimum=0.1,
        maximum=1,
        value=value,
        interactive=True,
        label=label,
    )


theme = gr.themes.Soft()

# NOTE(review): the layout nesting below (what sits inside the Row) was
# reconstructed because the original indentation was not recoverable — confirm.
with gr.Blocks(theme=theme) as demo:
    gr.Markdown(
        """
<h1 align="center">Receptacle Detector for Takeoff Automation</h1>
"""
    )

    with gr.Row():
        input_image = gr.Image(
            label="Upload an image here.",
            source="upload",
            interactive=True,
        )
        examples = gr.Examples(
            examples=[
                ["test1.jpg"],
                ["test2.jpg"],
                ["test3.jpg"],
                ["test4.jpg"],
            ],
            inputs=[input_image],
            examples_per_page=4,
            label="Examples to use.",
        )

    hide_labels = gr.Checkbox(label="Hide labels")

    # One slider per class; the order must match call_func's parameter order.
    with gr.Accordion("Visualization Confidence Thresholds", open=False):
        singleplex_slider = _threshold_slider("Singleplex", 0.53)
        duplex_slider = _threshold_slider("Duplex", 0.66)
        triplex_slider = _threshold_slider("Triplex", 0.65)
        quadruplex_slider = _threshold_slider("Quadruplex", 0.63)
        gfci_slider = _threshold_slider("GFCI", 0.31)
        gfci_wp_slider = _threshold_slider("GFCI/WP", 0.33)

    results_button = gr.Button("Submit")

    # Output components are created after the button so they render below it,
    # in the same order as when they were built inline in the click() call.
    output_image = gr.Image(type="numpy", label="Output Image")
    output_json = gr.Json()

    results_button.click(
        call_func,
        inputs=[
            input_image,
            hide_labels,
            singleplex_slider,
            duplex_slider,
            triplex_slider,
            quadruplex_slider,
            gfci_slider,
            gfci_wp_slider,
        ],
        outputs=[output_image, output_json],
    )

demo.launch()