sfoy commited on
Commit
16e71ba
·
verified ·
1 Parent(s): 5626e2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -44
app.py CHANGED
@@ -1,64 +1,72 @@
1
  import gradio as gr
2
  from huggingface_hub import hf_hub_download
3
- import torch
4
 
5
  def download_models(model_id):
6
- """
7
- Downloads a model file from Hugging Face Hub to a specified local directory.
8
-
9
- Parameters:
10
- - model_id (str): Identifier of the model to download.
11
 
12
- Returns:
13
- - str: Path to the downloaded model file.
14
- """
15
- model_path = hf_hub_download(repo_id="merve/yolov9", filename=model_id)
16
- return model_path
17
-
18
- def yolov9_inference(img_path, model_id="model_weights.pth", image_size=640, conf_threshold=0.25, iou_threshold=0.45):
19
  """
20
  Performs object detection using a YOLOv9 model. This function loads a specified YOLOv9 model,
21
  configures it based on the provided parameters, and carries out inference on a given image.
22
  Additionally, it allows for optional modification of the input size and the application of
23
  test time augmentation to potentially improve detection accuracy.
24
-
25
- Parameters:
26
- - img_path (str): The file path to the image on which inference is to be performed.
27
- - model_id (str): Identifier of the model to use.
28
- - image_size (int): The input size for inference.
29
- - conf_threshold (float): The confidence threshold used during Non-Maximum Suppression.
30
- - iou_threshold (float): The Intersection over Union threshold applied in NMS.
31
-
32
- Returns:
33
- - Image: An image with detection bounding boxes drawn on it.
34
  """
35
- # Import YOLOv9 and torch only when the function is called to save on initial script load time
36
- from yolov9 import YOLOv9
37
- from PIL import Image
38
- import numpy as np
39
-
40
- # Download and load the model
41
  model_path = download_models(model_id)
42
- model = YOLOv9(model_path, conf_threshold=conf_threshold, iou_threshold=iou_threshold, img_size=image_size)
43
- model.eval() # Set the model to evaluation mode
 
 
 
 
 
 
44
 
45
- # Load image
46
- img = Image.open(img_path).convert("RGB")
47
- img = np.array(img)
 
48
 
49
- # Perform inference
50
- results = model.predict(img, size=image_size)
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- # Extract results and visualize
53
- output_image = model.visualize(results, img)
54
- return output_image
55
 
56
- # Example Gradio interface setup (simplified for demonstration purposes)
57
- def gradio_interface(img_path):
58
- return yolov9_inference(img_path)
 
 
59
 
60
- iface = gr.Interface(fn=gradio_interface, inputs="image", outputs="image", title="YOLOv9 Object Detection")
61
- iface.launch()
62
 
 
63
 
 
 
 
 
 
 
 
 
64
 
 
 
 
1
  import gradio as gr
2
  from huggingface_hub import hf_hub_download
 
3
 
4
  def download_models(model_id):
5
+ model_file_path = hf_hub_download("merve/yolov9", filename=model_id)
6
+ return model_file_path
 
 
 
7
 
8
+ def yolov9_inference(img_path, model_id, image_size, conf_threshold, iou_threshold):
 
 
 
 
 
 
9
  """
10
  Performs object detection using a YOLOv9 model. This function loads a specified YOLOv9 model,
11
  configures it based on the provided parameters, and carries out inference on a given image.
12
  Additionally, it allows for optional modification of the input size and the application of
13
  test time augmentation to potentially improve detection accuracy.
 
 
 
 
 
 
 
 
 
 
14
  """
15
+ # Import YOLOv9
16
+ import yolov9
17
+
18
+ # Load the model
 
 
19
  model_path = download_models(model_id)
20
+ model = yolov9.load(model_path, device="cuda:0")
21
+
22
+ # Set model parameters
23
+ model.conf = conf_threshold
24
+ model.iou = iou_threshold
25
+
26
+ # Perform inference
27
+ results = model(img_path, size=image_size)
28
 
29
+ # Optionally, show detection bounding boxes on image
30
+ output = results.render()
31
+
32
+ return output[0]
33
 
34
+ def app():
35
+ with gr.Blocks() as blocks:
36
+ with gr.Row():
37
+ with gr.Column():
38
+ img_path = gr.Image(type="filepath", label="Image")
39
+ model_id = gr.Dropdown(
40
+ label="Model",
41
+ choices=["gelan-c.pt", "gelan-e.pt", "yolov9-c.pt", "yolov9-e.pt"],
42
+ value="gelan-e.pt"
43
+ )
44
+ image_size = gr.Slider(label="Image Size", minimum=320, maximum=1280, step=32, value=640)
45
+ conf_threshold = gr.Slider(label="Confidence Threshold", minimum=0.1, maximum=1.0, step=0.1, value=0.4)
46
+ iou_threshold = gr.Slider(label="IoU Threshold", minimum=0.1, maximum=1.0, step=0.1, value=0.5)
47
+ yolov9_infer = gr.Button("Inference")
48
 
49
+ with gr.Column():
50
+ output_image = gr.Image(type="numpy", label="Output")
 
51
 
52
+ yolov9_infer.click(
53
+ fn=yolov9_inference,
54
+ inputs=[img_path, model_id, image_size, conf_threshold, iou_threshold],
55
+ outputs=[output_image]
56
+ )
57
 
58
+ return blocks
 
59
 
60
+ gradio_app = app()
61
 
62
+ # Display a title using HTML, centered.
63
+ gradio_app[''].add(
64
+ gr.HTML("""
65
+ <h1 style='text-align: center; margin-bottom: 20px;'>
66
+ YOLOv9 from PipYoloV9 on my data
67
+ </h1>
68
+ """)
69
+ )
70
 
71
+ # Launch the Gradio app, enabling debug mode for detailed error logs and server information.
72
+ gradio_app.launch(debug=True)