tstone87 commited on
Commit
eadcbdc
·
verified ·
1 Parent(s): 04e7cfc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -4
app.py CHANGED
@@ -14,12 +14,15 @@ model = YOLO("yolo11n-pose.pt")
14
 
15
  def process_input(uploaded_file, youtube_link, image_url, sensitivity):
16
  """
17
- Process input from one of the three methods (Upload, YouTube, Image URL).
18
  Priority: YouTube link > Image URL > Uploaded file.
19
  The sensitivity slider value is passed as the confidence threshold.
20
 
21
- Returns a tuple:
22
- (download_file_path, display_file_path, status_message, dummy_state)
 
 
 
23
  """
24
  input_path = None
25
 
@@ -31,9 +34,112 @@ def process_input(uploaded_file, youtube_link, image_url, sensitivity):
31
  stream = yt.streams.filter(file_extension='mp4', progressive=True)\
32
  .order_by("resolution").desc().first()
33
  if stream is None:
34
- return None, None, "No suitable mp4 stream found.", ""
35
  input_path = stream.download()
36
  except Exception as e:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  return None, None, f"Error downloading video: {e}", ""
38
  # Priority 2: Image URL
39
  elif image_url and image_url.strip():
 
14
 
15
  def process_input(uploaded_file, youtube_link, image_url, sensitivity):
16
  """
17
+ Process input from one of three methods (Upload, YouTube, Image URL).
18
  Priority: YouTube link > Image URL > Uploaded file.
19
  The sensitivity slider value is passed as the confidence threshold.
20
 
21
+ Returns a tuple of 4 items:
22
+ 1. download_file_path (for gr.File)
23
+ 2. image_result (for gr.Image) or None
24
+ 3. video_result (for gr.Video) or None
25
+ 4. status message
26
  """
27
  input_path = None
28
 
 
34
  stream = yt.streams.filter(file_extension='mp4', progressive=True)\
35
  .order_by("resolution").desc().first()
36
  if stream is None:
37
+ return None, None, None, "No suitable mp4 stream found."
38
  input_path = stream.download()
39
  except Exception as e:
40
+ return None, None, None, f"Error downloading video: {e}"
41
+ # Priority 2: Image URL
42
+ elif image_url and image_url.strip():
43
+ try:
44
+ response = requests.get(image_url, stream=True)
45
+ if response.status_code != 200:
46
+ return None, None, None, f"Error downloading image: HTTP {response.status_code}"
47
+ temp_image_path = os.path.join(tempfile.gettempdir(), "downloaded_image.jpg")
48
+ with open(temp_image_path, "wb") as f:
49
+ f.write(response.content)
50
+ input_path = temp_image_path
51
+ except Exception as e:
52
+ return None, None, None, f"Error downloading image: {e}"
53
+ # Priority 3: Uploaded file
54
+ elif uploaded_file is not None:
55
+ input_path = uploaded_file.name
56
+ else:
57
+ return None, None, None, "Please provide an input using one of the methods."
58
+
59
+ try:
60
+ # Run prediction; pass slider value as confidence threshold.
61
+ results = model.predict(source=input_path, save=True, conf=sensitivity)
62
+ except Exception as e:
63
+ return None, None, None, f"Error running prediction: {e}"
64
+
65
+ output_path = None
66
+ try:
67
+ if hasattr(results[0], "save_path"):
68
+ output_path = results[0].save_path
69
+ else:
70
+ # If no save_path, generate annotated image using plot()
71
+ annotated = results[0].plot() # returns a numpy array
72
+ output_path = os.path.join(tempfile.gettempdir(), "annotated.jpg")
73
+ cv2.imwrite(output_path, annotated)
74
+ except Exception as e:
75
+ return None, None, None, f"Error processing the file: {e}"
76
+
77
+ # Clean up temporary input if it was downloaded.
78
+ if ((youtube_link and youtube_link.strip()) or (image_url and image_url.strip())) \
79
+ and input_path and os.path.exists(input_path):
80
+ os.remove(input_path)
81
+
82
+ # Determine if output is video or image based on extension.
83
+ ext = os.path.splitext(output_path)[1].lower()
84
+ video_exts = [".mp4", ".mov", ".avi", ".webm"]
85
+ if ext in video_exts:
86
+ image_result = None
87
+ video_result = output_path
88
+ else:
89
+ image_result = output_path
90
+ video_result = None
91
+
92
+ return output_path, image_result, video_result, "Success!"
93
+
94
+ # Build the Gradio interface.
95
+ with gr.Blocks(css="""
96
+ .result_img > img {
97
+ width: 100%;
98
+ height: auto;
99
+ object-fit: contain;
100
+ }
101
+ """) as demo:
102
+ # Layout: two columns in a row.
103
+ with gr.Row():
104
+ # Left column: Header image, title, input method tabs, and shared sensitivity slider.
105
+ with gr.Column(scale=1):
106
+ gr.HTML("<div style='text-align:center;'><img src='https://huggingface.co/spaces/tstone87/stance-detection/resolve/main/crowdresult.jpg' style='width:25%;'/></div>")
107
+ gr.Markdown("## Pose Detection with YOLO11-pose")
108
+ with gr.Tabs():
109
+ with gr.TabItem("Upload File"):
110
+ file_input = gr.File(label="Upload Image/Video")
111
+ with gr.TabItem("YouTube Link"):
112
+ youtube_input = gr.Textbox(label="YouTube Link", placeholder="https://...")
113
+ with gr.TabItem("Image URL"):
114
+ image_url_input = gr.Textbox(label="Image URL", placeholder="https://...")
115
+ sensitivity_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.5,
116
+ label="Sensitivity (Confidence Threshold)")
117
+ # Right column: Results displayed at the top.
118
+ with gr.Column(scale=2):
119
+ output_image = gr.Image(label="Annotated Output (Image)", elem_classes="result_img")
120
+ output_video = gr.Video(label="Annotated Output (Video)")
121
+ output_file = gr.File(label="Download Annotated Output")
122
+ output_text = gr.Textbox(label="Status", interactive=False)
123
+
124
+ # Set up automatic triggers for each input type.
125
+ file_input.change(
126
+ fn=process_input,
127
+ inputs=[file_input, gr.State(""), gr.State(""), sensitivity_slider],
128
+ outputs=[output_file, output_image, output_video, output_text]
129
+ )
130
+ youtube_input.change(
131
+ fn=process_input,
132
+ inputs=[gr.State(None), youtube_input, gr.State(""), sensitivity_slider],
133
+ outputs=[output_file, output_image, output_video, output_text]
134
+ )
135
+ image_url_input.change(
136
+ fn=process_input,
137
+ inputs=[gr.State(None), gr.State(""), image_url_input, sensitivity_slider],
138
+ outputs=[output_file, output_image, output_video, output_text]
139
+ )
140
+
141
+ if __name__ == "__main__":
142
+ demo.launch()
143
  return None, None, f"Error downloading video: {e}", ""
144
  # Priority 2: Image URL
145
  elif image_url and image_url.strip():