File size: 5,301 Bytes
5260c34
7c90da7
 
5260c34
 
7afb01c
7c90da7
 
fc38e58
5260c34
 
ef829ca
7c90da7
 
fc38e58
7c90da7
fc38e58
 
 
 
ef829ca
fc38e58
 
7c90da7
5260c34
 
7afb01c
7c90da7
 
 
 
 
 
 
fc38e58
7c90da7
 
fc38e58
7afb01c
 
 
 
 
fc38e58
7afb01c
 
 
 
 
fc38e58
7afb01c
7c90da7
 
 
fc38e58
5260c34
7c90da7
fc38e58
 
7c90da7
fc38e58
5260c34
 
 
 
 
 
7afb01c
5260c34
 
 
fc38e58
5260c34
fc38e58
 
7c90da7
bbc9616
fc38e58
7c90da7
fc38e58
 
 
 
 
 
 
 
 
6ab3263
ef829ca
fc38e58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef829ca
7c90da7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import shutil
import sys
import tempfile

import cv2
import gradio as gr
import requests
from ultralytics import YOLO

# The Spaces launcher might append an "--import" flag that this script does
# not understand; rebuild argv without any occurrence of it.
_SPACES_EXTRA_FLAG = "--import"
sys.argv = [token for token in sys.argv if token != _SPACES_EXTRA_FLAG]

# Instantiate the YOLO11 pose model once at import time (the weights file is
# downloaded automatically when it is not already present).
_POSE_WEIGHTS = "yolo11n-pose.pt"
model = YOLO(_POSE_WEIGHTS)

# Upper bound, in seconds, for the image download so that a stalled server
# cannot block a request indefinitely.
_DOWNLOAD_TIMEOUT_SECONDS = 30


def _download_youtube_video(youtube_link, target_dir):
    """Download the highest-resolution progressive mp4 stream into *target_dir*.

    Returns ``(path, None)`` on success or ``(None, status_message)`` on
    failure, so the caller can forward the message to the UI unchanged.
    """
    try:
        # Imported lazily: pytube is only needed for this input method.
        from pytube import YouTube

        stream = (
            YouTube(youtube_link)
            .streams.filter(file_extension="mp4", progressive=True)
            .order_by("resolution")
            .desc()
            .first()
        )
        if stream is None:
            return None, "No suitable mp4 stream found."
        return stream.download(output_path=target_dir), None
    except Exception as e:  # UI boundary: report any failure as a status message.
        return None, f"Error downloading video: {e}"


def _download_image(image_url, target_dir):
    """Fetch *image_url* and store it as ``downloaded_image.jpg`` in *target_dir*.

    Returns ``(path, None)`` on success or ``(None, status_message)`` on
    failure (non-200 response, network error, or write error).
    """
    try:
        # The whole body is read at once, so streaming mode is not needed;
        # the timeout keeps a dead server from hanging the request.
        response = requests.get(image_url, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
        if response.status_code != 200:
            return None, f"Error downloading image: HTTP {response.status_code}"
        image_path = os.path.join(target_dir, "downloaded_image.jpg")
        with open(image_path, "wb") as f:
            f.write(response.content)
        return image_path, None
    except Exception as e:  # UI boundary: report any failure as a status message.
        return None, f"Error downloading image: {e}"


def _write_annotated_output(results):
    """Return the path of the annotated output for the first prediction result.

    If the result object exposes a non-empty ``save_path`` it is used as-is;
    otherwise the first result is rendered with ``plot()`` and written to a
    uniquely named JPEG in the temp directory.

    Raises:
        ValueError: if *results* is empty.
        OSError: if the rendered image could not be written.
    """
    if not results:
        raise ValueError("the model returned no results")
    first = results[0]

    # NOTE(review): confirm that the result objects actually expose
    # `save_path`; if they never do, `save=True` in the predict call only
    # produces copies that this function does not use.
    saved_path = getattr(first, "save_path", None)
    if saved_path:
        return saved_path

    annotated = first.plot()  # returns a numpy array
    # A unique name keeps concurrent requests from overwriting each other.
    fd, output_path = tempfile.mkstemp(prefix="annotated_", suffix=".jpg")
    os.close(fd)
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(output_path, annotated):
        os.remove(output_path)
        raise OSError(f"could not write annotated image to {output_path}")
    return output_path


def process_input(uploaded_file, youtube_link, image_url, sensitivity):
    """
    Process input from one of the three methods (Upload, YouTube, Image URL).
    Priority: YouTube link > Image URL > Uploaded file.
    The sensitivity slider value is passed as the confidence threshold.

    Args:
        uploaded_file: The uploaded file, either as a path (``str`` /
            ``os.PathLike``) or as an object whose ``name`` attribute is the
            path; ``None`` when nothing was uploaded.
        youtube_link: YouTube URL; ``None``, empty or whitespace-only means
            "not provided".
        image_url: Image URL; ``None``, empty or whitespace-only means
            "not provided".
        sensitivity: Confidence threshold forwarded to ``model.predict``.

    Returns a tuple:
      (download_file_path, display_file_path, status_message, dummy_state)
    (The dummy_state is used because Gradio requires the same number of outputs.)
    On any failure both paths are ``None`` and the status message describes
    the error.
    """
    youtube_link = (youtube_link or "").strip()
    image_url = (image_url or "").strip()

    # Private directory for a downloaded input; removed in `finally` so the
    # download is cleaned up on error paths as well as on success.
    scratch_dir = None
    try:
        # Priority 1: YouTube link
        if youtube_link:
            scratch_dir = tempfile.mkdtemp(prefix="pose_input_")
            input_path, error = _download_youtube_video(youtube_link, scratch_dir)
        # Priority 2: Image URL
        elif image_url:
            scratch_dir = tempfile.mkdtemp(prefix="pose_input_")
            input_path, error = _download_image(image_url, scratch_dir)
        # Priority 3: Uploaded file
        elif uploaded_file is not None:
            # Depending on the Gradio version/configuration the upload arrives
            # either as a path or as a file wrapper with a `.name` path.
            if isinstance(uploaded_file, (str, os.PathLike)):
                input_path = os.fspath(uploaded_file)
            else:
                input_path = uploaded_file.name
            error = None
        else:
            input_path = None
            error = "Please provide an input using one of the methods."

        if error is not None:
            return None, None, error, ""

        try:
            # Pass the slider value as the confidence threshold.
            results = model.predict(source=input_path, save=True, conf=sensitivity)
        except Exception as e:
            return None, None, f"Error running prediction: {e}", ""

        try:
            output_path = _write_annotated_output(results)
        except Exception as e:
            return None, None, f"Error processing the file: {e}", ""

        return output_path, output_path, "Success!", ""
    finally:
        if scratch_dir is not None:
            shutil.rmtree(scratch_dir, ignore_errors=True)

# Build the Gradio interface with custom CSS for the result image.
with gr.Blocks(css="""
.result_img > img {
  width: 100%;
  height: auto;
  object-fit: contain;
}
""") as demo:
    # Header with scaled image (25% width) and title.
    # NOTE(review): confirm that '/crowdresult.jpg' is actually served by the
    # app; otherwise the header image will not load.
    gr.HTML("<div style='text-align:center;'><img src='/crowdresult.jpg' style='width:25%;'/></div>")
    gr.Markdown("## Pose Detection with YOLO11-pose")

    # Create two columns.
    with gr.Row():
        # Left column: Input tabs and sensitivity slider.
        with gr.Column(scale=1):
            with gr.Tabs():
                with gr.TabItem("Upload File"):
                    file_input = gr.File(label="Upload Image/Video")
                with gr.TabItem("YouTube Link"):
                    youtube_input = gr.Textbox(label="YouTube Link", placeholder="https://...")
                with gr.TabItem("Image URL"):
                    image_url_input = gr.Textbox(label="Image URL", placeholder="https://...")
            sensitivity_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.5,
                                           label="Sensitivity (Confidence Threshold)")
        # Right column: Display result.
        with gr.Column(scale=2):
            output_display = gr.Image(label="Annotated Output", elem_classes="result_img")
            output_file = gr.File(label="Download Annotated Output")
            output_text = gr.Textbox(label="Status", interactive=False)

    # Set up automatic triggers for each input type. Every listener calls
    # process_input with all four arguments; the two input methods that do not
    # belong to the triggering component are filled with constant gr.State
    # placeholders (None for the file, "" for the links), and the fourth
    # return value is absorbed by a throwaway gr.State output.
    # NOTE(review): these are `change` listeners, so the text boxes may trigger
    # processing while a URL is still being typed - confirm this is intended.
    file_input.change(
        fn=process_input,
        inputs=[file_input, gr.State(""), gr.State(""), sensitivity_slider],
        outputs=[output_file, output_display, output_text, gr.State()]
    )
    youtube_input.change(
        fn=process_input,
        inputs=[gr.State(None), youtube_input, gr.State(""), sensitivity_slider],
        outputs=[output_file, output_display, output_text, gr.State()]
    )
    image_url_input.change(
        fn=process_input,
        inputs=[gr.State(None), gr.State(""), image_url_input, sensitivity_slider],
        outputs=[output_file, output_display, output_text, gr.State()]
    )

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()