File size: 7,526 Bytes
5260c34
7c90da7
 
5260c34
 
7afb01c
7c90da7
 
fc38e58
5260c34
 
ef829ca
7c90da7
 
fc38e58
7c90da7
eadcbdc
fc38e58
 
 
680b8e8
 
 
 
 
 
 
 
7c90da7
5260c34
 
7afb01c
7c90da7
 
 
 
f34aee4
7c90da7
eadcbdc
7c90da7
 
eadcbdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
680b8e8
 
 
eadcbdc
 
 
680b8e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f34aee4
eadcbdc
 
680b8e8
 
 
eadcbdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f34aee4
eadcbdc
 
 
 
 
 
 
 
 
 
cff6cbb
eadcbdc
f34aee4
eadcbdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c90da7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import sys
import gradio as gr
import os
import tempfile
import cv2
import requests
from ultralytics import YOLO

# Hugging Face Spaces may append an "--import" flag to the command line;
# filter it out of sys.argv (rebinding the list) before anything else reads it.
sys.argv = list(filter(lambda cli_arg: cli_arg != "--import", sys.argv))

# Shared YOLO11-pose model used by every request (auto-downloads if needed).
model = YOLO("yolo11n-pose.pt")

# Input files with these extensions are treated as videos; anything else as an image.
_VIDEO_EXTS = (".mp4", ".mov", ".avi", ".webm")
# Seconds requests waits on the image server before giving up.
_DOWNLOAD_TIMEOUT_S = 30
# Frame rate used when the source container reports no usable FPS.
_DEFAULT_FPS = 30.0


class _ProcessingError(Exception):
    """Carry a user-facing status message from a helper back to process_input."""


def process_input(uploaded_file, youtube_link, image_url, sensitivity):
    """
    Run pose detection on one input and return the annotated result.

    Inputs come from one of three methods, with priority
    YouTube link > Image URL > Uploaded file. The sensitivity slider value is
    passed to the model as its confidence threshold (``conf``).

    Videos (mp4, mov, avi, webm) are processed frame by frame in streaming mode
    and re-encoded as an mp4. Anything else is processed as a single image.

    Args:
        uploaded_file: Value from ``gr.File``. This is a filepath string
            (Gradio 4) or an object exposing the path as ``.name`` (older
            Gradio). None if nothing was uploaded.
        youtube_link: YouTube URL text; empty/whitespace means "not used".
        image_url: Image URL text; empty/whitespace means "not used".
        sensitivity: Confidence threshold forwarded to ``model.predict``.

    Returns:
        A 4-tuple:
          - download_file_path (for gr.File), or None on error
          - image_result (for gr.Image) or None
          - video_result (for gr.Video) or None
          - status message ("Success!" or a description of the failure)
    """
    try:
        input_path, downloaded = _resolve_input(uploaded_file, youtube_link, image_url)
    except _ProcessingError as err:
        return None, None, None, str(err)

    try:
        if os.path.splitext(input_path)[1].lower() in _VIDEO_EXTS:
            output_path = _annotate_video(input_path, sensitivity)
        else:
            output_path = _annotate_image(input_path, sensitivity)
    except _ProcessingError as err:
        return None, None, None, str(err)
    finally:
        # Remove inputs we downloaded ourselves on success AND failure, so
        # failed requests do not leave files behind.
        if downloaded:
            _remove_quietly(input_path)

    # Route the result to the image or video preview based on its extension.
    if os.path.splitext(output_path)[1].lower() in _VIDEO_EXTS:
        return output_path, None, output_path, "Success!"
    return output_path, output_path, None, "Success!"


def _resolve_input(uploaded_file, youtube_link, image_url):
    """Pick the input by priority and return ``(path, downloaded)``.

    ``downloaded`` is True when this function created the file, meaning the
    caller should delete it afterwards. Raises _ProcessingError carrying the
    status message when nothing usable was provided or a download fails.
    """
    if youtube_link and youtube_link.strip():
        return _download_youtube(youtube_link), True
    if image_url and image_url.strip():
        return _download_image(image_url), True
    if uploaded_file is not None:
        # Gradio 4's gr.File hands over a filepath string; older versions
        # hand over a tempfile wrapper whose path is in ``.name``.
        if isinstance(uploaded_file, (str, os.PathLike)):
            return os.fspath(uploaded_file), False
        return uploaded_file.name, False
    raise _ProcessingError("Please provide an input using one of the methods.")


def _download_youtube(youtube_link):
    """Download the highest-resolution progressive mp4 stream and return its path.

    The file is written to pytube's default download location; this code
    does not choose the directory.
    """
    try:
        from pytube import YouTube  # imported lazily: only needed for this input type

        stream = (
            YouTube(youtube_link)
            .streams.filter(file_extension="mp4", progressive=True)
            .order_by("resolution")
            .desc()
            .first()
        )
        if stream is None:
            raise _ProcessingError("No suitable mp4 stream found.")
        return stream.download()
    except _ProcessingError:
        raise
    except Exception as e:
        raise _ProcessingError(f"Error downloading video: {e}") from e


def _download_image(image_url):
    """Fetch *image_url* into a uniquely named temporary .jpg file and return its path."""
    try:
        # The timeout keeps an unresponsive server from blocking this worker forever.
        response = requests.get(image_url, timeout=_DOWNLOAD_TIMEOUT_S)
        if response.status_code != 200:
            raise _ProcessingError(f"Error downloading image: HTTP {response.status_code}")
        # A unique name keeps concurrent requests from overwriting each other.
        fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
        with os.fdopen(fd, "wb") as f:
            f.write(response.content)
        return temp_image_path
    except _ProcessingError:
        raise
    except Exception as e:
        raise _ProcessingError(f"Error downloading image: {e}") from e


def _new_output_path(filename):
    """Return a path named *filename* inside a fresh private temp directory.

    A per-call directory keeps concurrent results apart while keeping a
    friendly download name. The file is not deleted here: it is returned to
    the Gradio output components, which still need to read it.
    """
    return os.path.join(tempfile.mkdtemp(), filename)


def _remove_quietly(path):
    """Best-effort delete of a temporary input file; cleanup failure is not fatal."""
    try:
        os.remove(path)
    except OSError:
        # Deliberately ignored: the result has already been produced (or the
        # real error already reported), and a leftover temp file is harmless.
        pass


def _annotate_video(input_path, sensitivity):
    """Run pose inference on every frame of a video and write an annotated mp4.

    Frames are written as they are produced instead of being buffered, so
    memory use does not grow with video length. Returns the output path.
    """
    try:
        cap = cv2.VideoCapture(input_path)
        try:
            if not cap.isOpened():
                raise _ProcessingError("Error opening video file.")
            fps = cap.get(cv2.CAP_PROP_FPS)
        finally:
            cap.release()
        # Some containers report 0 (or NaN) FPS; the writer needs a positive rate.
        if not (fps > 0):
            fps = _DEFAULT_FPS

        temp_video_path = _new_output_path("annotated_video.mp4")
        writer = None
        try:
            for result in model.predict(source=input_path, stream=True, conf=sensitivity):
                # result.plot() returns an annotated frame (numpy array).
                annotated_frame = result.plot()
                if writer is None:
                    # Size the writer from the first real frame. Frames must
                    # match the size the writer was opened with.
                    height, width = annotated_frame.shape[:2]
                    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                    writer = cv2.VideoWriter(temp_video_path, fourcc, fps, (width, height))
                    if not writer.isOpened():
                        raise _ProcessingError(
                            "Error processing video: could not open video writer."
                        )
                writer.write(annotated_frame)
        finally:
            if writer is not None:
                writer.release()

        if writer is None:
            raise _ProcessingError("No detections were returned from video streaming.")
        return temp_video_path
    except _ProcessingError:
        raise
    except Exception as e:
        raise _ProcessingError(f"Error processing video: {e}") from e


def _annotate_image(input_path, sensitivity):
    """Run pose inference on a single image and return the annotated image path."""
    try:
        # NOTE(review): save=True makes ultralytics write its own annotated
        # copy as well; it only matters if `save_path` below is populated by
        # the installed ultralytics version — confirm whether it is needed.
        results = model.predict(source=input_path, save=True, conf=sensitivity)
    except Exception as e:
        raise _ProcessingError(f"Error running prediction: {e}") from e

    try:
        if not results:
            raise _ProcessingError("No detections were returned.")
        save_path = getattr(results[0], "save_path", None)
        if save_path:
            return save_path
        annotated = results[0].plot()  # returns a numpy array
        output_path = _new_output_path("annotated.jpg")
        if not cv2.imwrite(output_path, annotated):
            raise _ProcessingError(
                "Error processing the file: could not write annotated image."
            )
        return output_path
    except _ProcessingError:
        raise
    except Exception as e:
        raise _ProcessingError(f"Error processing the file: {e}") from e

# Gradio UI: inputs on the left, annotated results on the right.
with gr.Blocks(css="""
.result_img > img {
  width: 100%;
  height: auto;
  object-fit: contain;
}
""") as demo:
    with gr.Row():
        # Left Column: Header image, title, input tabs, and sensitivity slider.
        with gr.Column(scale=1):
            gr.HTML("<div style='text-align:center;'><img src='https://huggingface.co/spaces/tstone87/stance-detection/resolve/main/crowdresult.jpg' style='width:25%;'/></div>")
            gr.Markdown("## Pose Detection with YOLO11-pose")
            with gr.Tabs():
                with gr.TabItem("Upload File"):
                    file_input = gr.File(label="Upload Image/Video")
                with gr.TabItem("YouTube Link"):
                    youtube_input = gr.Textbox(label="YouTube Link",
                                               placeholder="https://... (press Enter to run)")
                with gr.TabItem("Image URL"):
                    image_url_input = gr.Textbox(label="Image URL",
                                                 placeholder="https://... (press Enter to run)")
            sensitivity_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.5,
                                           label="Sensitivity (Confidence Threshold)")
        # Right Column: Results display at the top.
        with gr.Column(scale=2):
            output_image = gr.Image(label="Annotated Output (Image)", elem_classes="result_img")
            output_video = gr.Video(label="Annotated Output (Video)")
            output_file = gr.File(label="Download Annotated Output")
            output_text = gr.Textbox(label="Status", interactive=False)

    # Every trigger fills the same four components, in process_input's return order.
    result_outputs = [output_file, output_image, output_video, output_text]

    # An upload is a complete value, so process it as soon as it arrives.
    # The unused input methods are fed constant empty values via gr.State.
    file_input.change(
        fn=process_input,
        inputs=[file_input, gr.State(""), gr.State(""), sensitivity_slider],
        outputs=result_outputs,
    )
    # Textbox.change fires on every edit (each keystroke), which would start a
    # download + inference for every partial URL; run only on submit (Enter).
    youtube_input.submit(
        fn=process_input,
        inputs=[gr.State(None), youtube_input, gr.State(""), sensitivity_slider],
        outputs=result_outputs,
    )
    image_url_input.submit(
        fn=process_input,
        inputs=[gr.State(None), gr.State(""), image_url_input, sensitivity_slider],
        outputs=result_outputs,
    )

if __name__ == "__main__":
    demo.launch()