tstone87 commited on
Commit
ef829ca
·
verified ·
1 Parent(s): 8e825ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -26
app.py CHANGED
@@ -9,14 +9,15 @@ from ultralytics import YOLO
9
  # Remove extra CLI arguments (like "--import") from Spaces.
10
  sys.argv = [arg for arg in sys.argv if arg != "--import"]
11
 
12
- # Load the YOLO11-pose model (will auto-download if needed)
13
  model = YOLO("yolo11n-pose.pt")
14
 
15
  def process_input(uploaded_file, youtube_link, image_url):
16
  """
17
  Process an uploaded file, a YouTube link, or an image URL for pose detection.
18
- Returns a tuple: (download_file_path, display_file_path, status_message).
19
- Priority: YouTube link > Image URL > Uploaded file.
 
20
  """
21
  input_path = None
22
 
@@ -58,48 +59,60 @@ def process_input(uploaded_file, youtube_link, image_url):
58
 
59
  output_path = None
60
  try:
61
- # If the results object has a save_path attribute, use it.
62
  if hasattr(results[0], "save_path"):
63
  output_path = results[0].save_path
64
  else:
65
- # Otherwise, generate an annotated image using plot() and save it manually.
66
  annotated = results[0].plot() # returns a numpy array
67
  output_path = os.path.join(tempfile.gettempdir(), "annotated.jpg")
68
  cv2.imwrite(output_path, annotated)
69
  except Exception as e:
70
  return None, None, f"Error processing the file: {e}"
71
 
72
- # Clean up the temporary input file if downloaded.
73
  if (youtube_link or (image_url and image_url.strip())) and input_path and os.path.exists(input_path):
74
  os.remove(input_path)
75
 
76
- # Return the same output path for both download and display.
77
  return output_path, output_path, "Success!"
78
 
79
  # Define the Gradio interface.
80
  with gr.Blocks() as demo:
81
- gr.Markdown("# Pose Detection with YOLO11-pose")
82
- gr.Image(value="crowdresult.jpg", label="Crowd Result", interactive=False)
83
- gr.Markdown("Upload an image/video, provide an image URL, or supply a YouTube link to detect human poses.")
84
-
85
- with gr.Row():
86
- file_input = gr.File(label="Upload Image/Video")
87
- with gr.Row():
88
- youtube_input = gr.Textbox(label="YouTube Link", placeholder="https://...")
89
- image_url_input = gr.Textbox(label="Image URL", placeholder="https://...")
90
-
91
- # Three outputs: one for file download, one for immediate display, and one for status text.
92
  output_file = gr.File(label="Download Annotated Output")
93
  output_display = gr.Image(label="Annotated Output")
94
  output_text = gr.Textbox(label="Status", interactive=False)
95
- run_button = gr.Button("Run Pose Detection")
96
-
97
- run_button.click(
98
- process_input,
99
- inputs=[file_input, youtube_input, image_url_input],
100
- outputs=[output_file, output_display, output_text]
101
- )
102
 
103
- # Only launch the interface if executed directly.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  if __name__ == "__main__":
105
  demo.launch()
 
9
  # Remove extra CLI arguments (like "--import") from Spaces.
10
  sys.argv = [arg for arg in sys.argv if arg != "--import"]
11
 
12
+ # Load the YOLO11-pose model (auto-downloads if needed)
13
  model = YOLO("yolo11n-pose.pt")
14
 
15
  def process_input(uploaded_file, youtube_link, image_url):
16
  """
17
  Process an uploaded file, a YouTube link, or an image URL for pose detection.
18
+ Priority is: YouTube link > Image URL > Uploaded file.
19
+ Returns a tuple:
20
+ (download_file_path, display_file_path, status_message)
21
  """
22
  input_path = None
23
 
 
59
 
60
  output_path = None
61
  try:
62
+ # If the result has a save_path attribute, use it.
63
  if hasattr(results[0], "save_path"):
64
  output_path = results[0].save_path
65
  else:
66
+ # Otherwise, use plot() to get an annotated image and save it.
67
  annotated = results[0].plot() # returns a numpy array
68
  output_path = os.path.join(tempfile.gettempdir(), "annotated.jpg")
69
  cv2.imwrite(output_path, annotated)
70
  except Exception as e:
71
  return None, None, f"Error processing the file: {e}"
72
 
73
+ # Clean up temporary file if it was downloaded (from YouTube or Image URL)
74
  if (youtube_link or (image_url and image_url.strip())) and input_path and os.path.exists(input_path):
75
  os.remove(input_path)
76
 
 
77
  return output_path, output_path, "Success!"
78
 
79
  # Define the Gradio interface.
80
  with gr.Blocks() as demo:
81
+ # Header image scaled down to 25% using HTML.
82
+ gr.HTML("<div style='text-align: center;'><img src='crowdresult.jpg' style='width:25%;'></div>")
83
+ gr.Markdown("## Pose Detection with YOLO11-pose")
84
+ gr.Markdown("Choose one of the input methods below. The pose detection will run automatically when you provide an input, and the annotated result will be displayed below.")
85
+
86
+ # Prepare output components (they will display the annotated result and a download link)
 
 
 
 
 
87
  output_file = gr.File(label="Download Annotated Output")
88
  output_display = gr.Image(label="Annotated Output")
89
  output_text = gr.Textbox(label="Status", interactive=False)
 
 
 
 
 
 
 
90
 
91
+ # Create three tabs for the different input methods.
92
+ with gr.Tabs():
93
+ with gr.TabItem("Upload File"):
94
+ file_input = gr.File(label="Upload Image/Video")
95
+ # Automatically run detection when a file is uploaded.
96
+ file_input.change(
97
+ fn=process_input,
98
+ inputs=[file_input, gr.State(None), gr.State(None)],
99
+ outputs=[output_file, output_display, output_text]
100
+ )
101
+ with gr.TabItem("YouTube Link"):
102
+ youtube_input = gr.Textbox(label="YouTube Link", placeholder="https://...")
103
+ youtube_input.change(
104
+ fn=process_input,
105
+ inputs=[gr.State(None), youtube_input, gr.State(None)],
106
+ outputs=[output_file, output_display, output_text]
107
+ )
108
+ with gr.TabItem("Image URL"):
109
+ image_url_input = gr.Textbox(label="Image URL", placeholder="https://...")
110
+ image_url_input.change(
111
+ fn=process_input,
112
+ inputs=[gr.State(None), gr.State(None), image_url_input],
113
+ outputs=[output_file, output_display, output_text]
114
+ )
115
+
116
+ # Only launch the interface if the script is executed directly.
117
  if __name__ == "__main__":
118
  demo.launch()