# app.py — Hugging Face Space by sfoy
# Last commit: bea8391 ("Update app.py", verified) · file size 2.54 kB
import gradio as gr
from huggingface_hub import hf_hub_download
def download_models(model_id, repo_id="merve/yolov9"):
    """Download a YOLOv9 checkpoint from the Hugging Face Hub and return its local path.

    Args:
        model_id: Checkpoint filename inside the repository, e.g. ``"gelan-e.pt"``.
        repo_id: Hub repository that holds the checkpoints. Defaults to
            ``"merve/yolov9"`` (the previously hard-coded value), so existing
            callers are unaffected; pass another repo id to use your own weights.

    Returns:
        The local filesystem path returned by ``hf_hub_download``.
    """
    return hf_hub_download(repo_id, filename=model_id)
def yolov9_inference(img_path, model_id, image_size, conf_threshold, iou_threshold):
    """Run YOLOv9 object detection on one image and return the annotated image.

    The selected checkpoint is fetched with ``download_models``, loaded on the CPU,
    configured with the given thresholds and applied to the image. No test-time
    augmentation is performed.

    Args:
        img_path: Path to the input image, as produced by ``gr.Image(type="filepath")``;
            ``None`` when the user has not uploaded an image.
        model_id: Checkpoint filename to download, e.g. ``"gelan-e.pt"``.
        image_size: Inference size, passed to the model as ``size``.
        conf_threshold: Confidence threshold, assigned to ``model.conf``.
        iou_threshold: IoU threshold, assigned to ``model.iou``.

    Returns:
        The first element of ``results.render()``, i.e. the input image with the
        detections drawn on it (displayed by a ``gr.Image(type="numpy")`` output, so
        presumably a NumPy array — confirm against the yolov9 package).

    Raises:
        gr.Error: If no image was provided.
    """
    if img_path is None:
        # gr.Image yields None when the button is clicked before an upload; report it
        # in the UI instead of letting the model fail with an opaque error.
        raise gr.Error("Please upload an image before running inference.")

    # Deferred import: yolov9 is only needed once inference actually runs.
    import yolov9

    model_path = download_models(model_id)
    model = yolov9.load(model_path, device="cpu")
    model.conf = conf_threshold
    model.iou = iou_threshold

    # Gradio sliders may deliver the value as a float; the size is a pixel count.
    results = model(img_path, size=int(image_size))

    # render() draws the detection boxes; presumably one image per input, and we
    # passed exactly one.
    rendered = results.render()
    return rendered[0]
def app():
    """Build and return the YOLOv9 detection UI as a ``gr.Blocks`` (not launched).

    Left column: image input, model dropdown, size / confidence / IoU sliders and an
    "Inference" button. Right column: the annotated output image. Clicking the
    button calls ``yolov9_inference`` with the five inputs in order.
    """
    checkpoint_choices = ["gelan-c.pt", "gelan-e.pt", "yolov9-c.pt", "yolov9-e.pt"]

    with gr.Blocks() as ui:
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="filepath", label="Image")
                model_choice = gr.Dropdown(
                    label="Model",
                    choices=checkpoint_choices,
                    value="gelan-e.pt",
                )
                size_slider = gr.Slider(
                    label="Image Size", minimum=320, maximum=1280, step=32, value=640
                )
                conf_slider = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    value=0.4,
                )
                iou_slider = gr.Slider(
                    label="IoU Threshold",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    value=0.5,
                )
                run_button = gr.Button("Inference")
            with gr.Column():
                annotated_output = gr.Image(type="numpy", label="Output")

        # Order must match yolov9_inference's positional parameters.
        handler_inputs = [image_input, model_choice, size_slider, conf_slider, iou_slider]
        run_button.click(
            fn=yolov9_inference,
            inputs=handler_inputs,
            outputs=[annotated_output],
        )

    return ui
# Build the detection UI, then mount it inside an outer Blocks so a centered title
# can sit above it. (gr.Blocks is not subscriptable, so the previous
# `gradio_app[''].add(...)` raised TypeError before the app could launch.)
detection_app = app()

TITLE_HTML = """
<h1 style='text-align: center; margin-bottom: 20px;'>
YOLOv9 from PipYoloV9 on my data
</h1>
"""

with gr.Blocks() as gradio_app:
    # Display a title using HTML, centered.
    gr.HTML(TITLE_HTML)
    # render() mounts the already-built Blocks, including its event listeners.
    detection_app.render()

if __name__ == "__main__":
    # Launch the Gradio app, enabling debug mode for detailed error logs and server information.
    gradio_app.launch(debug=True)