tstone87 commited on
Commit
5260c34
·
verified ·
1 Parent(s): 2a46a78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -21
app.py CHANGED
@@ -1,58 +1,86 @@
 
1
  import gradio as gr
2
  import os
 
 
3
  from ultralytics import YOLO
4
 
5
- # Load the YOLO11-pose model (auto-downloads if not present)
 
 
 
6
  model = YOLO("yolo11n-pose.pt")
7
 
8
  def process_input(uploaded_file, youtube_link):
9
  """
10
  Process an uploaded file or a YouTube link to perform pose detection.
11
- Returns the path to the annotated output.
 
12
  """
 
 
 
 
13
  if youtube_link and youtube_link.strip():
14
  try:
15
  from pytube import YouTube
16
  yt = YouTube(youtube_link)
 
17
  stream = yt.streams.filter(file_extension='mp4', progressive=True)\
18
  .order_by("resolution").desc().first()
19
  if stream is None:
20
- return "No suitable mp4 stream found."
21
  input_path = stream.download()
22
  except Exception as e:
23
- return f"Error downloading video: {e}"
24
  elif uploaded_file is not None:
25
  input_path = uploaded_file.name
26
  else:
27
- return "Please provide an uploaded file or a YouTube link."
28
-
29
- # Run pose detection and save the annotated output.
30
- results = model.predict(source=input_path, save=True)
31
-
32
  try:
33
- output_path = results[0].save_path
34
  except Exception as e:
35
- return f"Error processing the file: {e}"
36
-
37
- # Optionally remove the downloaded video if applicable.
38
- if youtube_link and os.path.exists(input_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  os.remove(input_path)
40
 
41
- return output_path
42
-
43
- # Define the Gradio Blocks interface as a global variable.
44
- demo = gr.Blocks()
45
 
46
- with demo:
 
47
  gr.Markdown("# Pose Detection with YOLO11-pose")
48
  gr.Markdown("Upload an image/video or provide a YouTube link to detect human poses.")
 
49
  with gr.Row():
50
  file_input = gr.File(label="Upload Image/Video")
51
  youtube_input = gr.Textbox(label="Or enter a YouTube link", placeholder="https://...")
 
52
  output_file = gr.File(label="Download Annotated Output")
 
53
  run_button = gr.Button("Run Pose Detection")
54
- run_button.click(process_input, inputs=[file_input, youtube_input], outputs=output_file)
 
 
55
 
56
- # Only launch the interface if this file is executed directly.
57
  if __name__ == "__main__":
58
  demo.launch()
 
1
+ import sys
2
  import gradio as gr
3
  import os
4
+ import tempfile
5
+ import cv2
6
  from ultralytics import YOLO
7
 
8
+ # Optionally remove extra CLI arguments that Spaces might pass
9
+ sys.argv = [arg for arg in sys.argv if arg != "--import"]
10
+
11
+ # Load the YOLO11-pose model (it will auto-download if not present)
12
  model = YOLO("yolo11n-pose.pt")
13
 
14
  def process_input(uploaded_file, youtube_link):
15
  """
16
  Process an uploaded file or a YouTube link to perform pose detection.
17
+ Returns a tuple: (annotated_file_path, status_message).
18
+ If an error occurs, annotated_file_path is None and status_message describes the error.
19
  """
20
+ error_message = ""
21
+ input_path = None
22
+
23
+ # Check for input: either a YouTube link or an uploaded file.
24
  if youtube_link and youtube_link.strip():
25
  try:
26
  from pytube import YouTube
27
  yt = YouTube(youtube_link)
28
+ # Get the highest resolution progressive mp4 stream
29
  stream = yt.streams.filter(file_extension='mp4', progressive=True)\
30
  .order_by("resolution").desc().first()
31
  if stream is None:
32
+ return None, "No suitable mp4 stream found."
33
  input_path = stream.download()
34
  except Exception as e:
35
+ return None, f"Error downloading video: {e}"
36
  elif uploaded_file is not None:
37
  input_path = uploaded_file.name
38
  else:
39
+ return None, "Please provide an uploaded file or a YouTube link."
40
+
41
+ # Run pose detection (with save=True so that outputs are written to disk)
 
 
42
  try:
43
+ results = model.predict(source=input_path, save=True)
44
  except Exception as e:
45
+ return None, f"Error running prediction: {e}"
46
+
47
+ # Try to get the annotated output file:
48
+ output_path = None
49
+ try:
50
+ # Some YOLO versions may offer a 'save_path' attribute.
51
+ if hasattr(results[0], "save_path"):
52
+ output_path = results[0].save_path
53
+ else:
54
+ # Fallback: generate the annotated image using result.plot()
55
+ annotated = results[0].plot() # returns a numpy array with annotations
56
+ # Save the annotated image to a temporary file.
57
+ output_path = os.path.join(tempfile.gettempdir(), "annotated.jpg")
58
+ cv2.imwrite(output_path, annotated)
59
+ except Exception as e:
60
+ return None, f"Error processing the file: {e}"
61
+
62
+ # Clean up the downloaded video file if it came from YouTube.
63
+ if youtube_link and input_path and os.path.exists(input_path):
64
  os.remove(input_path)
65
 
66
+ return output_path, "Success!"
 
 
 
67
 
68
+ # Define the Gradio interface with two outputs: one for the file and one for status text.
69
+ with gr.Blocks() as demo:
70
  gr.Markdown("# Pose Detection with YOLO11-pose")
71
  gr.Markdown("Upload an image/video or provide a YouTube link to detect human poses.")
72
+
73
  with gr.Row():
74
  file_input = gr.File(label="Upload Image/Video")
75
  youtube_input = gr.Textbox(label="Or enter a YouTube link", placeholder="https://...")
76
+
77
  output_file = gr.File(label="Download Annotated Output")
78
+ output_text = gr.Textbox(label="Status", interactive=False)
79
  run_button = gr.Button("Run Pose Detection")
80
+
81
+ run_button.click(process_input, inputs=[file_input, youtube_input],
82
+ outputs=[output_file, output_text])
83
 
84
+ # Only launch the app if the script is executed directly.
85
  if __name__ == "__main__":
86
  demo.launch()