tstone87 commited on
Commit
7afb01c
·
verified ·
1 Parent(s): 3fd2f76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -27
app.py CHANGED
@@ -3,29 +3,28 @@ import gradio as gr
3
  import os
4
  import tempfile
5
  import cv2
 
6
  from ultralytics import YOLO
7
 
8
- # Optionally remove extra CLI arguments that Spaces might pass
9
  sys.argv = [arg for arg in sys.argv if arg != "--import"]
10
 
11
- # Load the YOLO11-pose model (it will auto-download if not present)
12
  model = YOLO("yolo11n-pose.pt")
13
 
14
- def process_input(uploaded_file, youtube_link):
15
  """
16
- Process an uploaded file or a YouTube link to perform pose detection.
17
- Returns a tuple: (annotated_file_path, status_message).
18
- If an error occurs, annotated_file_path is None and status_message describes the error.
19
  """
20
- error_message = ""
21
  input_path = None
22
 
23
- # Check for input: either a YouTube link or an uploaded file.
24
  if youtube_link and youtube_link.strip():
25
  try:
26
  from pytube import YouTube
27
  yt = YouTube(youtube_link)
28
- # Get the highest resolution progressive mp4 stream
29
  stream = yt.streams.filter(file_extension='mp4', progressive=True)\
30
  .order_by("resolution").desc().first()
31
  if stream is None:
@@ -33,54 +32,70 @@ def process_input(uploaded_file, youtube_link):
33
  input_path = stream.download()
34
  except Exception as e:
35
  return None, f"Error downloading video: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  elif uploaded_file is not None:
37
  input_path = uploaded_file.name
38
  else:
39
- return None, "Please provide an uploaded file or a YouTube link."
40
 
41
- # Run pose detection (with save=True so that outputs are written to disk)
42
  try:
43
  results = model.predict(source=input_path, save=True)
44
  except Exception as e:
45
  return None, f"Error running prediction: {e}"
46
 
47
- # Try to get the annotated output file:
48
  output_path = None
49
  try:
50
- # Some YOLO versions may offer a 'save_path' attribute.
51
  if hasattr(results[0], "save_path"):
52
  output_path = results[0].save_path
53
  else:
54
- # Fallback: generate the annotated image using result.plot()
55
- annotated = results[0].plot() # returns a numpy array with annotations
56
- # Save the annotated image to a temporary file.
57
  output_path = os.path.join(tempfile.gettempdir(), "annotated.jpg")
58
  cv2.imwrite(output_path, annotated)
59
  except Exception as e:
60
  return None, f"Error processing the file: {e}"
61
 
62
- # Clean up the downloaded video file if it came from YouTube.
63
- if youtube_link and input_path and os.path.exists(input_path):
64
  os.remove(input_path)
65
-
66
  return output_path, "Success!"
67
 
68
- # Define the Gradio interface with two outputs: one for the file and one for status text.
69
  with gr.Blocks() as demo:
 
70
  gr.Markdown("# Pose Detection with YOLO11-pose")
71
- gr.Markdown("Upload an image/video or provide a YouTube link to detect human poses.")
72
-
 
73
  with gr.Row():
74
  file_input = gr.File(label="Upload Image/Video")
75
- youtube_input = gr.Textbox(label="Or enter a YouTube link", placeholder="https://...")
76
-
 
 
77
  output_file = gr.File(label="Download Annotated Output")
78
  output_text = gr.Textbox(label="Status", interactive=False)
79
  run_button = gr.Button("Run Pose Detection")
80
-
81
- run_button.click(process_input, inputs=[file_input, youtube_input],
82
  outputs=[output_file, output_text])
83
 
84
- # Only launch the app if the script is executed directly.
85
  if __name__ == "__main__":
86
  demo.launch()
 
3
  import os
4
  import tempfile
5
  import cv2
6
+ import requests
7
  from ultralytics import YOLO
8
 
9
+ # Remove extra CLI arguments (like "--import") from Spaces.
10
  sys.argv = [arg for arg in sys.argv if arg != "--import"]
11
 
12
+ # Load the YOLO11-pose model (will auto-download if needed)
13
  model = YOLO("yolo11n-pose.pt")
14
 
15
+ def process_input(uploaded_file, youtube_link, image_url):
16
  """
17
+ Process an uploaded file, a YouTube link, or an image URL for pose detection.
18
+ Returns a tuple: (annotated_file_path, status_message).
19
+ Priority is given to the YouTube link first, then image URL, then the uploaded file.
20
  """
 
21
  input_path = None
22
 
23
+ # Priority 1: YouTube link
24
  if youtube_link and youtube_link.strip():
25
  try:
26
  from pytube import YouTube
27
  yt = YouTube(youtube_link)
 
28
  stream = yt.streams.filter(file_extension='mp4', progressive=True)\
29
  .order_by("resolution").desc().first()
30
  if stream is None:
 
32
  input_path = stream.download()
33
  except Exception as e:
34
  return None, f"Error downloading video: {e}"
35
+ # Priority 2: Image URL
36
+ elif image_url and image_url.strip():
37
+ try:
38
+ response = requests.get(image_url, stream=True)
39
+ if response.status_code != 200:
40
+ return None, f"Error downloading image: HTTP {response.status_code}"
41
+ # Save the downloaded image to a temporary file.
42
+ temp_image_path = os.path.join(tempfile.gettempdir(), "downloaded_image.jpg")
43
+ with open(temp_image_path, "wb") as f:
44
+ f.write(response.content)
45
+ input_path = temp_image_path
46
+ except Exception as e:
47
+ return None, f"Error downloading image: {e}"
48
+ # Priority 3: Uploaded file
49
  elif uploaded_file is not None:
50
  input_path = uploaded_file.name
51
  else:
52
+ return None, "Please provide a YouTube link, image URL, or upload a file."
53
 
54
+ # Run pose detection (with save=True so annotated results are written to disk)
55
  try:
56
  results = model.predict(source=input_path, save=True)
57
  except Exception as e:
58
  return None, f"Error running prediction: {e}"
59
 
 
60
  output_path = None
61
  try:
62
+ # If the results object has a save_path attribute, use it.
63
  if hasattr(results[0], "save_path"):
64
  output_path = results[0].save_path
65
  else:
66
+ # Otherwise, generate an annotated image using plot() and save it manually.
67
+ annotated = results[0].plot() # returns a numpy array
 
68
  output_path = os.path.join(tempfile.gettempdir(), "annotated.jpg")
69
  cv2.imwrite(output_path, annotated)
70
  except Exception as e:
71
  return None, f"Error processing the file: {e}"
72
 
73
+ # If the input came from YouTube or image URL, remove the temporary file.
74
+ if (youtube_link or (image_url and image_url.strip())) and input_path and os.path.exists(input_path):
75
  os.remove(input_path)
76
+
77
  return output_path, "Success!"
78
 
79
+ # Define the Gradio interface.
80
  with gr.Blocks() as demo:
81
+ # Display the default image (crowdresult.jpg) at the top.
82
  gr.Markdown("# Pose Detection with YOLO11-pose")
83
+ gr.Image(value="crowdresult.jpg", label="Crowd Result", interactive=False)
84
+ gr.Markdown("Upload an image/video, provide an image URL, or supply a YouTube link to detect human poses.")
85
+
86
  with gr.Row():
87
  file_input = gr.File(label="Upload Image/Video")
88
+ with gr.Row():
89
+ youtube_input = gr.Textbox(label="YouTube Link", placeholder="https://...")
90
+ image_url_input = gr.Textbox(label="Image URL", placeholder="https://...")
91
+
92
  output_file = gr.File(label="Download Annotated Output")
93
  output_text = gr.Textbox(label="Status", interactive=False)
94
  run_button = gr.Button("Run Pose Detection")
95
+
96
+ run_button.click(process_input, inputs=[file_input, youtube_input, image_url_input],
97
  outputs=[output_file, output_text])
98
 
99
+ # Only launch the interface if executed directly.
100
  if __name__ == "__main__":
101
  demo.launch()