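"""Gradio demo: ship detection with oriented bounding boxes on SPOT satellite images.

The app sends the uploaded image to the DL4EO inference API and draws the
returned oriented bounding boxes (class name and confidence) on top of it.
"""
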
import base64
import json

import gradio as gr
import numpy as np
import requests
from PIL import Image, ImageDraw, ImageFont

# Inference endpoint and credentials for the DL4EO ship-detection API.
DL4EO_API_URL = "https://dl4eo--ship-predict.modal.run"
DL4EO_API_KEY = 'dprY8HYkE9iXeCS4JnGjch5B'

# Width (in pixels) of the oriented bounding boxes drawn on the output image.
LINE_WIDTH = 2

print(f"Gradio version: {gr.__version__}")


def predict_image(image, threshold):
    """Run ship detection on a PIL image and draw the detected oriented bounding boxes.

    Returns the annotated image, its shape, the number of detections and the
    inference duration (in seconds) reported by the API.
    """
    if not isinstance(image, Image.Image):
        raise TypeError("predict_image(): input 'image' should be a single RGB image in PIL format.")

    img = np.array(image)
    if len(img.shape) != 3 or img.shape[2] != 3:
        raise ValueError("predict_image(): input 'image' should be a single RGB image in PIL format.")

    # Base64-encode the raw pixel buffer and send it to the inference API.
    image_base64 = base64.b64encode(np.ascontiguousarray(img)).decode()

    payload = {
        'image': image_base64,
        'shape': img.shape,
        'threshold': threshold,
    }

    headers = {
        'Authorization': 'Bearer ' + DL4EO_API_KEY,
        'Content-Type': 'application/json'
    }

    response = requests.post(DL4EO_API_URL, json=payload, headers=headers)

    if response.status_code != 200:
        raise RuntimeError(
            f"Received status code={response.status_code} from inference API"
        )

    json_data = response.json()
    detections = json_data['detections']
    duration = json_data['duration']

    # Draw each oriented bounding box and its label on the input image.
    draw = ImageDraw.Draw(image)
    font = ImageFont.truetype("coolvetica_condensed_rg.otf", 24)  # font file shipped alongside this script

    for detection in detections:
        coords = detection['xyxyxyxy']
        if len(coords) != 4:
            raise ValueError("Each detection should be a polygon with 4 coordinates (xyxyxyxy).")

        points = [(coord[0], coord[1]) for coord in coords]
        draw.polygon(points, outline="white", width=LINE_WIDTH)

        # Center the label horizontally above the box.
        min_x = min(point[0] for point in points)
        max_x = max(point[0] for point in points)
        min_y = min(point[1] for point in points)

        label = f"{detection['class_name']} | {round(detection['confidence'], 3)}"
        text_width, text_height = draw.textbbox((0, 0), label, font=font)[2:]
        text_x = (min_x + max_x) / 2 - text_width / 2
        draw.text((text_x, min_y - text_height - LINE_WIDTH), label, fill="white", font=font)

    return image, img.shape, len(detections), duration
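
# Example of calling predict_image() directly (outside the Gradio UI), assuming one
# of the bundled demo images below is available:
#
#   annotated, shape, n_ships, seconds = predict_image(Image.open("./demo/12ab97857.jpg"), 0.6)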


example_data = [
    ["./demo/12ab97857.jpg", 0.6],
    ["./demo/82f13510a.jpg", 0.6],
    ["./demo/836f35381.jpg", 0.6],
    ["./demo/848d2afef.jpg", 0.6],
    ["./demo/911b25478.jpg", 0.6],
    ["./demo/b86e4046f.jpg", 0.6],
    ["./demo/ce2220f49.jpg", 0.6],
    ["./demo/d9762ef5e.jpg", 0.6],
    ["./demo/fa613751e.jpg", 0.6],
]

css = """
.image-preview {
    height: 768px !important;
    width: 768px !important;
}
"""

TITLE = "Ship detection on SPOT satellite images (Oriented Bounding Boxes)"

demo = gr.Blocks(title=TITLE, css=css).queue()
with demo:
    gr.Markdown(f"<h1><center>{TITLE}</center></h1>")

    with gr.Row():
        with gr.Column(scale=0):
            input_image = gr.Image(type="pil", interactive=True)
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=True):
                threshold = gr.Slider(label="Confidence threshold", minimum=0.0, maximum=1.0, value=0.60, step=0.01)
                dimensions = gr.Textbox(label="Image size", interactive=False)
                detections = gr.Textbox(label="Predicted objects", interactive=False)
                stopwatch = gr.Number(label="Execution time (sec.)", interactive=False, precision=3)

        with gr.Column(scale=2):
            output_image = gr.Image(type="pil", elem_classes='image-preview', interactive=False)

    run_button.click(fn=predict_image, inputs=[input_image, threshold], outputs=[output_image, dimensions, detections, stopwatch])

    gr.Examples(
        examples=example_data,
        inputs=[input_image, threshold],
        outputs=[output_image, dimensions, detections, stopwatch],
        fn=predict_image,
        cache_examples=True,
        label='Try these images!'
    )

    gr.Markdown("""
<p>This demo is provided by <a href='https://www.linkedin.com/in/faudi/'>Jeff Faudi</a> and <a href='https://www.dl4eo.com/'>DL4EO</a>.
This model is based on the <a href='https://www.ultralytics.com/yolo'>Ultralytics YOLOv8-OBB</a> framework, which provides oriented bounding boxes.
We believe that oriented bounding boxes are better suited for the detection of ships in satellite images. The model has been trained on the
<a href='https://www.kaggle.com/c/airbus-ship-detection/data'>Airbus Ship Detection dataset</a> available on Kaggle, which provides SPOT extracts at 1.5 m resolution
provided by Airbus DS. The associated license is <a href='https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en'>CC-BY-NC-SA</a>.</p>
<p>This demonstration CANNOT be used for commercial purposes. Please contact <a href='mailto:[email protected]'>me</a> for more information on
how you could get access to a commercial-grade model or API.</p>
""")

demo.launch(
    inline=False,
    show_api=False,
    debug=False
)