tstone87 committed
Commit fc38e58 · verified · 1 Parent(s): ef829ca

Update app.py

Files changed (1)
  1. app.py +66 -55
app.py CHANGED
@@ -6,18 +6,21 @@ import cv2
 import requests
 from ultralytics import YOLO
 
-# Remove extra CLI arguments (like "--import") from Spaces.
+# Remove extra CLI arguments that Spaces might pass.
 sys.argv = [arg for arg in sys.argv if arg != "--import"]
 
 # Load the YOLO11-pose model (auto-downloads if needed)
 model = YOLO("yolo11n-pose.pt")
 
-def process_input(uploaded_file, youtube_link, image_url):
+def process_input(uploaded_file, youtube_link, image_url, sensitivity):
     """
-    Process an uploaded file, a YouTube link, or an image URL for pose detection.
-    Priority is: YouTube link > Image URL > Uploaded file.
+    Process input from one of the three methods (Upload, YouTube, Image URL).
+    Priority: YouTube link > Image URL > Uploaded file.
+    The sensitivity slider value is passed as the confidence threshold.
+
     Returns a tuple:
-        (download_file_path, display_file_path, status_message)
+        (download_file_path, display_file_path, status_message, dummy_state)
+    (The dummy_state is used because Gradio requires the same number of outputs.)
     """
     input_path = None
 
@@ -29,90 +32,98 @@ def process_input(uploaded_file, youtube_link, image_url):
             stream = yt.streams.filter(file_extension='mp4', progressive=True)\
                         .order_by("resolution").desc().first()
             if stream is None:
-                return None, None, "No suitable mp4 stream found."
+                return None, None, "No suitable mp4 stream found.", ""
             input_path = stream.download()
         except Exception as e:
-            return None, None, f"Error downloading video: {e}"
+            return None, None, f"Error downloading video: {e}", ""
     # Priority 2: Image URL
     elif image_url and image_url.strip():
         try:
             response = requests.get(image_url, stream=True)
             if response.status_code != 200:
-                return None, None, f"Error downloading image: HTTP {response.status_code}"
+                return None, None, f"Error downloading image: HTTP {response.status_code}", ""
             temp_image_path = os.path.join(tempfile.gettempdir(), "downloaded_image.jpg")
             with open(temp_image_path, "wb") as f:
                 f.write(response.content)
             input_path = temp_image_path
         except Exception as e:
-            return None, None, f"Error downloading image: {e}"
+            return None, None, f"Error downloading image: {e}", ""
     # Priority 3: Uploaded file
     elif uploaded_file is not None:
         input_path = uploaded_file.name
     else:
-        return None, None, "Please provide a YouTube link, image URL, or upload a file."
+        return None, None, "Please provide an input using one of the methods.", ""
 
-    # Run pose detection (with save=True so annotated outputs are written to disk)
     try:
-        results = model.predict(source=input_path, save=True)
+        # Pass the slider value as the confidence threshold.
+        results = model.predict(source=input_path, save=True, conf=sensitivity)
     except Exception as e:
-        return None, None, f"Error running prediction: {e}"
+        return None, None, f"Error running prediction: {e}", ""
 
     output_path = None
     try:
-        # If the result has a save_path attribute, use it.
         if hasattr(results[0], "save_path"):
             output_path = results[0].save_path
         else:
-            # Otherwise, use plot() to get an annotated image and save it.
            annotated = results[0].plot() # returns a numpy array
            output_path = os.path.join(tempfile.gettempdir(), "annotated.jpg")
            cv2.imwrite(output_path, annotated)
     except Exception as e:
-        return None, None, f"Error processing the file: {e}"
+        return None, None, f"Error processing the file: {e}", ""
 
-    # Clean up temporary file if it was downloaded (from YouTube or Image URL)
-    if (youtube_link or (image_url and image_url.strip())) and input_path and os.path.exists(input_path):
+    # Clean up the temporary input if it was downloaded.
+    if ((youtube_link and youtube_link.strip()) or (image_url and image_url.strip())) and input_path and os.path.exists(input_path):
         os.remove(input_path)
 
-    return output_path, output_path, "Success!"
+    return output_path, output_path, "Success!", ""
 
-# Define the Gradio interface.
-with gr.Blocks() as demo:
-    # Header image scaled down to 25% using HTML.
-    gr.HTML("<div style='text-align: center;'><img src='crowdresult.jpg' style='width:25%;'></div>")
+# Build the Gradio interface with custom CSS for the result image.
+with gr.Blocks(css="""
+.result_img > img {
+    width: 100%;
+    height: auto;
+    object-fit: contain;
+}
+""") as demo:
+    # Header with scaled image (25% width) and title.
+    gr.HTML("<div style='text-align:center;'><img src='crowdresult.jpg' style='width:25%;'/></div>")
     gr.Markdown("## Pose Detection with YOLO11-pose")
-    gr.Markdown("Choose one of the input methods below. The pose detection will run automatically when you provide an input, and the annotated result will be displayed below.")
-
-    # Prepare output components (they will display the annotated result and a download link)
-    output_file = gr.File(label="Download Annotated Output")
-    output_display = gr.Image(label="Annotated Output")
-    output_text = gr.Textbox(label="Status", interactive=False)
-
-    # Create three tabs for the different input methods.
-    with gr.Tabs():
-        with gr.TabItem("Upload File"):
-            file_input = gr.File(label="Upload Image/Video")
-            # Automatically run detection when a file is uploaded.
-            file_input.change(
-                fn=process_input,
-                inputs=[file_input, gr.State(None), gr.State(None)],
-                outputs=[output_file, output_display, output_text]
-            )
-        with gr.TabItem("YouTube Link"):
-            youtube_input = gr.Textbox(label="YouTube Link", placeholder="https://...")
-            youtube_input.change(
-                fn=process_input,
-                inputs=[gr.State(None), youtube_input, gr.State(None)],
-                outputs=[output_file, output_display, output_text]
-            )
-        with gr.TabItem("Image URL"):
-            image_url_input = gr.Textbox(label="Image URL", placeholder="https://...")
-            image_url_input.change(
-                fn=process_input,
-                inputs=[gr.State(None), gr.State(None), image_url_input],
-                outputs=[output_file, output_display, output_text]
-            )
+
+    # Create two columns.
+    with gr.Row():
+        # Left column: Input tabs and sensitivity slider.
+        with gr.Column(scale=1):
+            with gr.Tabs():
+                with gr.TabItem("Upload File"):
+                    file_input = gr.File(label="Upload Image/Video")
+                with gr.TabItem("YouTube Link"):
+                    youtube_input = gr.Textbox(label="YouTube Link", placeholder="https://...")
+                with gr.TabItem("Image URL"):
+                    image_url_input = gr.Textbox(label="Image URL", placeholder="https://...")
+            sensitivity_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.5,
+                                           label="Sensitivity (Confidence Threshold)")
+        # Right column: Display result.
+        with gr.Column(scale=2):
+            output_display = gr.Image(label="Annotated Output", elem_classes="result_img")
+            output_file = gr.File(label="Download Annotated Output")
+            output_text = gr.Textbox(label="Status", interactive=False)
+
+    # Set up automatic triggers for each input type.
+    file_input.change(
+        fn=process_input,
+        inputs=[file_input, gr.State(""), gr.State(""), sensitivity_slider],
+        outputs=[output_file, output_display, output_text, gr.State()]
+    )
+    youtube_input.change(
+        fn=process_input,
+        inputs=[gr.State(None), youtube_input, gr.State(""), sensitivity_slider],
+        outputs=[output_file, output_display, output_text, gr.State()]
+    )
+    image_url_input.change(
+        fn=process_input,
+        inputs=[gr.State(None), gr.State(""), image_url_input, sensitivity_slider],
+        outputs=[output_file, output_display, output_text, gr.State()]
    )
 
-# Only launch the interface if the script is executed directly.
 if __name__ == "__main__":
     demo.launch()