tstone87 committed on
Commit
7aa805e
·
verified ·
1 Parent(s): 30fd2ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -82
app.py CHANGED
@@ -6,129 +6,120 @@ import cv2
6
  import requests
7
  from ultralytics import YOLO
8
 
9
- # Remove extra CLI arguments that Spaces might pass.
10
  sys.argv = [arg for arg in sys.argv if arg != "--import"]
11
 
12
- # Load the YOLO11-pose model (auto-downloads if needed)
13
  model = YOLO("yolo11n-pose.pt")
14
 
15
  def process_input(uploaded_file, youtube_link, image_url, sensitivity):
16
  """
17
- Process input from one of three methods (Upload, YouTube, Image URL).
18
  Priority: YouTube link > Image URL > Uploaded file.
19
- The sensitivity slider value is passed as the confidence threshold.
20
-
21
- For video files (mp4, mov, avi, webm), we process the video frame-by-frame
22
- using OpenCV. For images, we use normal prediction.
23
-
24
- Returns a tuple:
25
- - download_file_path (for gr.File)
26
- - image_result (for gr.Image) or None
27
- - video_result (for gr.Video) or None
28
- - status message
29
  """
30
  input_path = None
 
31
 
32
  # Priority 1: YouTube link
33
  if youtube_link and youtube_link.strip():
34
  try:
35
- from pytube import YouTube
36
  yt = YouTube(youtube_link)
37
  stream = yt.streams.filter(file_extension='mp4', progressive=True).order_by("resolution").desc().first()
38
- if stream is None:
39
  return None, None, None, "No suitable mp4 stream found."
40
- input_path = stream.download()
 
 
 
41
  except Exception as e:
42
- return None, None, None, f"Error downloading video: {e}"
 
43
  # Priority 2: Image URL
44
  elif image_url and image_url.strip():
45
  try:
46
- response = requests.get(image_url, stream=True)
47
- if response.status_code != 200:
48
- return None, None, None, f"Error downloading image: HTTP {response.status_code}"
49
- temp_image_path = os.path.join(tempfile.gettempdir(), "downloaded_image.jpg")
50
- with open(temp_image_path, "wb") as f:
51
  f.write(response.content)
52
- input_path = temp_image_path
 
53
  except Exception as e:
54
- return None, None, None, f"Error downloading image: {e}"
 
55
  # Priority 3: Uploaded file
56
  elif uploaded_file is not None:
57
  input_path = uploaded_file.name
58
  else:
59
- return None, None, None, "Please provide an input using one of the methods."
60
 
61
- # Determine if input is a video (by extension).
62
- ext_input = os.path.splitext(input_path)[1].lower()
63
  video_exts = [".mp4", ".mov", ".avi", ".webm"]
64
-
65
  output_path = None
66
 
67
- if ext_input in video_exts:
68
- # Process video frame-by-frame using OpenCV.
69
- try:
70
  cap = cv2.VideoCapture(input_path)
71
  if not cap.isOpened():
72
  return None, None, None, "Error opening video file."
 
73
  fps = cap.get(cv2.CAP_PROP_FPS)
74
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
75
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
76
- frames = []
 
 
 
 
 
77
  while True:
78
  ret, frame = cap.read()
79
  if not ret:
80
  break
81
- # Run detection on the frame.
82
- # Note: model.predict() accepts an image (numpy array) as source.
83
- result = model.predict(source=frame, conf=sensitivity)[0]
84
- annotated_frame = result.plot() # returns an annotated frame (numpy array)
85
- frames.append(annotated_frame)
 
 
 
 
86
  cap.release()
87
- if not frames:
88
- return None, None, None, "No detections were returned from video processing."
89
- # Write annotated frames to a temporary video file.
90
- temp_video_path = os.path.join(tempfile.gettempdir(), "annotated_video.mp4")
91
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
92
- out = cv2.VideoWriter(temp_video_path, fourcc, fps, (width, height))
93
- for frame in frames:
94
- out.write(frame)
95
  out.release()
96
- output_path = temp_video_path
97
- except Exception as e:
98
- return None, None, None, f"Error processing video: {e}"
99
- else:
100
- # Process as an image.
101
- try:
102
- results = model.predict(source=input_path, save=True, conf=sensitivity)
103
- except Exception as e:
104
- return None, None, None, f"Error running prediction: {e}"
105
- try:
106
- if not results or len(results) == 0:
107
- return None, None, None, "No detections were returned."
108
- if hasattr(results[0], "save_path"):
109
- output_path = results[0].save_path
110
- else:
111
- annotated = results[0].plot() # returns a numpy array
112
- output_path = os.path.join(tempfile.gettempdir(), "annotated.jpg")
113
- cv2.imwrite(output_path, annotated)
114
- except Exception as e:
115
- return None, None, None, f"Error processing the file: {e}"
116
 
117
- # Clean up temporary input if downloaded.
118
- if ((youtube_link and youtube_link.strip()) or (image_url and image_url.strip())) and input_path and os.path.exists(input_path):
119
- os.remove(input_path)
 
 
 
 
 
120
 
121
- # Set outputs based on output file extension.
122
- ext_output = os.path.splitext(output_path)[1].lower()
123
- if ext_output in video_exts:
124
- image_result = None
125
- video_result = output_path
126
- else:
127
- image_result = output_path
128
- video_result = None
129
 
130
- return output_path, image_result, video_result, "Success!"
 
 
 
 
 
 
 
131
 
 
132
  with gr.Blocks(css="""
133
  .result_img > img {
134
  width: 100%;
@@ -137,7 +128,6 @@ with gr.Blocks(css="""
137
  }
138
  """) as demo:
139
  with gr.Row():
140
- # Left Column: Header image, title, input tabs, and sensitivity slider.
141
  with gr.Column(scale=1):
142
  gr.HTML("<div style='text-align:center;'><img src='https://huggingface.co/spaces/tstone87/stance-detection/resolve/main/crowdresult.jpg' style='width:25%;'/></div>")
143
  gr.Markdown("## Pose Detection with YOLO11-pose")
@@ -148,9 +138,8 @@ with gr.Blocks(css="""
148
  youtube_input = gr.Textbox(label="YouTube Link", placeholder="https://...")
149
  with gr.TabItem("Image URL"):
150
  image_url_input = gr.Textbox(label="Image URL", placeholder="https://...")
151
- sensitivity_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.15,
152
- label="Sensitivity (Confidence Threshold)")
153
- # Right Column: Results display at the top.
154
  with gr.Column(scale=2):
155
  output_image = gr.Image(label="Annotated Output (Image)", elem_classes="result_img")
156
  output_video = gr.Video(label="Annotated Output (Video)")
@@ -174,4 +163,4 @@ with gr.Blocks(css="""
174
  )
175
 
176
  if __name__ == "__main__":
177
- demo.launch()
 
6
  import requests
7
  from ultralytics import YOLO
8
 
9
# Strip extra CLI arguments that Hugging Face Spaces may inject
sys.argv = [arg for arg in sys.argv if arg != "--import"]

# Load the YOLO11-pose model (Ultralytics downloads the weights on first use)
model = YOLO("yolo11n-pose.pt")
14
 
15
def process_input(uploaded_file, youtube_link, image_url, sensitivity):
    """Run YOLO pose detection on a single input source.

    Exactly one source is used, in priority order:
    YouTube link > Image URL > Uploaded file.

    Args:
        uploaded_file: Gradio file object (exposes a ``.name`` path) or None.
        youtube_link: YouTube URL string; may be empty/None.
        image_url: Direct image URL string; may be empty/None.
        sensitivity: Confidence threshold forwarded to ``model.predict``.

    Returns:
        Tuple ``(download_path, image_result, video_result, status)``.
        On success exactly one of ``image_result`` / ``video_result`` is a
        path and the other is None; on failure all three paths are None and
        ``status`` describes the error.
    """
    input_path = None
    # Temp *inputs* this function downloaded and must delete afterwards.
    # Outputs are tracked separately (in output_path) and are kept so Gradio
    # can serve them. The previous temp_files[:-1] scheme leaked the input
    # whenever processing failed before an output was appended.
    downloaded_inputs = []

    # Priority 1: YouTube link.
    if youtube_link and youtube_link.strip():
        try:
            from pytubefix import YouTube  # pytubefix: maintained fork of pytube
            yt = YouTube(youtube_link)
            stream = (yt.streams
                        .filter(file_extension='mp4', progressive=True)
                        .order_by("resolution").desc().first())
            if not stream:
                return None, None, None, "No suitable mp4 stream found."
            # Random name avoids collisions between concurrent requests.
            temp_path = os.path.join(tempfile.gettempdir(), f"yt_{os.urandom(8).hex()}.mp4")
            stream.download(output_path=tempfile.gettempdir(),
                            filename=os.path.basename(temp_path))
            input_path = temp_path
            downloaded_inputs.append(input_path)
        except Exception as e:
            return None, None, None, f"Error downloading YouTube video: {str(e)}"

    # Priority 2: Image URL.
    elif image_url and image_url.strip():
        try:
            response = requests.get(image_url, stream=True, timeout=10)
            response.raise_for_status()  # turn HTTP errors into exceptions
            temp_path = os.path.join(tempfile.gettempdir(), f"img_{os.urandom(8).hex()}.jpg")
            with open(temp_path, "wb") as f:
                f.write(response.content)
            input_path = temp_path
            downloaded_inputs.append(input_path)
        except Exception as e:
            return None, None, None, f"Error downloading image: {str(e)}"

    # Priority 3: Uploaded file.
    elif uploaded_file is not None:
        input_path = uploaded_file.name
    else:
        return None, None, None, "Please provide an input."

    ext = os.path.splitext(input_path)[1].lower()
    video_exts = [".mp4", ".mov", ".avi", ".webm"]

    try:
        if ext in video_exts:
            # --- Video: annotate frame-by-frame and re-encode. ---
            cap = cv2.VideoCapture(input_path)
            if not cap.isOpened():
                return None, None, None, "Error opening video file."

            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps or fps <= 0:
                # Some streams report 0 fps; a 0-fps VideoWriter is unusable.
                fps = 30.0
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            output_path = os.path.join(tempfile.gettempdir(), f"out_{os.urandom(8).hex()}.mp4")
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
            try:
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    # NOTE(review): frames are converted BGR->RGB before
                    # predict and back afterwards; Ultralytics accepts BGR
                    # numpy arrays directly, so the round trip is kept only
                    # to preserve prior behavior — confirm against the
                    # ultralytics predict docs before removing.
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    result = model.predict(source=frame_rgb, conf=sensitivity)[0]
                    annotated = result.plot()
                    out.write(cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))
            finally:
                # Release even if prediction raises mid-loop, so the partial
                # output file is finalized and handles are not leaked.
                cap.release()
                out.release()

            if os.path.getsize(output_path) == 0:
                return None, None, None, "Error: Output video is empty."

            return output_path, None, output_path, "Video processed successfully!"

        # --- Image: single prediction, write the annotated copy. ---
        result = model.predict(source=input_path, conf=sensitivity)[0]
        annotated = result.plot()
        output_path = os.path.join(tempfile.gettempdir(), f"out_{os.urandom(8).hex()}.jpg")
        cv2.imwrite(output_path, annotated)
        return output_path, output_path, None, "Image processed successfully!"

    except Exception as e:
        return None, None, None, f"Processing error: {str(e)}"

    finally:
        # Delete only the temp inputs we downloaded; annotated outputs must
        # survive for Gradio to serve them.
        for path in downloaded_inputs:
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except OSError:
                    pass  # best-effort cleanup; never mask the real result
121
 
122
+ # Gradio interface remains mostly the same
123
  with gr.Blocks(css="""
124
  .result_img > img {
125
  width: 100%;
 
128
  }
129
  """) as demo:
130
  with gr.Row():
 
131
  with gr.Column(scale=1):
132
  gr.HTML("<div style='text-align:center;'><img src='https://huggingface.co/spaces/tstone87/stance-detection/resolve/main/crowdresult.jpg' style='width:25%;'/></div>")
133
  gr.Markdown("## Pose Detection with YOLO11-pose")
 
138
  youtube_input = gr.Textbox(label="YouTube Link", placeholder="https://...")
139
  with gr.TabItem("Image URL"):
140
  image_url_input = gr.Textbox(label="Image URL", placeholder="https://...")
141
+ sensitivity_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2,
142
+ label="Sensitivity (Confidence Threshold)")
 
143
  with gr.Column(scale=2):
144
  output_image = gr.Image(label="Annotated Output (Image)", elem_classes="result_img")
145
  output_video = gr.Video(label="Annotated Output (Video)")
 
163
  )
164
 
165
  if __name__ == "__main__":
166
+ demo.launch()