File size: 4,146 Bytes
5260c34
7c90da7
 
5260c34
 
7afb01c
7c90da7
 
7afb01c
5260c34
 
7afb01c
7c90da7
 
7afb01c
7c90da7
7afb01c
 
 
7c90da7
5260c34
 
7afb01c
7c90da7
 
 
 
 
 
 
5260c34
7c90da7
 
5260c34
7afb01c
 
 
 
 
 
 
 
 
 
 
 
 
 
7c90da7
 
 
7afb01c
5260c34
7afb01c
7c90da7
5260c34
7c90da7
5260c34
 
 
 
7afb01c
5260c34
 
 
7afb01c
 
5260c34
 
 
 
 
7afb01c
 
7c90da7
7afb01c
5260c34
7c90da7
7afb01c
5260c34
7afb01c
7c90da7
7afb01c
 
 
7c90da7
 
7afb01c
 
 
 
7c90da7
5260c34
7c90da7
7afb01c
 
5260c34
7c90da7
7afb01c
7c90da7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import sys
import gradio as gr
import os
import tempfile
import cv2
import requests
from ultralytics import YOLO

# The hosting environment (Spaces) may launch the script with an extra
# "--import" CLI argument; strip every occurrence of it from sys.argv.
sys.argv = list(filter(lambda arg: arg != "--import", sys.argv))

# YOLO11 nano pose-estimation model; the weights file is fetched automatically
# by ultralytics when it is not already present locally.
model = YOLO("yolo11n-pose.pt")

def _download_youtube_video(youtube_link, dest_dir):
    """
    Download the highest-resolution progressive mp4 stream of a YouTube video.

    Args:
        youtube_link: URL of the YouTube video.
        dest_dir: directory the video file is written into.

    Returns:
        Path of the downloaded file, or None when the video offers no
        progressive mp4 stream. pytube/network errors propagate to the caller.
    """
    # Imported lazily, presumably so the app still starts when pytube is
    # missing and only YouTube input is affected.
    from pytube import YouTube

    stream = (
        YouTube(youtube_link)
        .streams.filter(file_extension="mp4", progressive=True)
        .order_by("resolution")
        .desc()
        .first()
    )
    if stream is None:
        return None
    return stream.download(output_path=dest_dir)


def _download_image(image_url, dest_dir, timeout=30):
    """
    Download ``image_url`` into ``dest_dir`` and return the local file path.

    Args:
        image_url: direct URL of the image.
        dest_dir: directory the image file is written into.
        timeout: seconds to wait for the server before giving up.

    Raises:
        RuntimeError: the server answered with a status other than 200; the
            message is "HTTP <status>".
        requests.RequestException: connection failures or a timeout.
    """
    # Without a timeout a stalled server would block this request forever;
    # the `with` closes the connection on every exit path.
    with requests.get(image_url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise RuntimeError(f"HTTP {response.status_code}")
        # NOTE(review): the file is always named .jpg whatever format the URL
        # serves; presumably the image loader sniffs the content — confirm.
        image_path = os.path.join(dest_dir, "downloaded_image.jpg")
        with open(image_path, "wb") as f:
            # Stream to disk in chunks instead of buffering the whole body.
            for chunk in response.iter_content(chunk_size=65536):
                f.write(chunk)
    return image_path


def _annotated_output_path(results):
    """
    Return a path to an annotated output file for prediction results.

    Uses the first result's ``save_path`` when it is set. Otherwise it renders
    the first result with ``plot()`` and writes it to a new, uniquely named
    JPEG in the system temp directory, so concurrent requests never overwrite
    each other's output.

    Raises:
        ValueError: ``results`` is empty.
        RuntimeError: OpenCV failed to write the rendered image.
    """
    if not results:
        raise ValueError("no prediction results were produced")
    first = results[0]
    saved_path = getattr(first, "save_path", None)
    if saved_path:
        return saved_path
    # NOTE(review): for video input `results` presumably holds one entry per
    # frame, so this fallback annotates only the first frame — confirm.
    annotated = first.plot()  # returns the annotated image as a numpy array
    fd, output_path = tempfile.mkstemp(prefix="annotated_", suffix=".jpg")
    os.close(fd)
    if not cv2.imwrite(output_path, annotated):
        os.remove(output_path)
        raise RuntimeError("OpenCV could not write the annotated image")
    return output_path


def process_input(uploaded_file, youtube_link, image_url):
    """
    Run YOLO pose detection on a YouTube video, an image URL, or an uploaded file.

    Exactly one source is used, in this priority order: YouTube link, then
    image URL, then uploaded file. Links that are empty, None, or only
    whitespace are ignored.

    Args:
        uploaded_file: value of the Gradio File component, or None. Depending
            on the Gradio version this may be a filepath string or an object
            with a ``.name`` attribute; both are accepted.
        youtube_link: YouTube URL text, possibly empty.
        image_url: direct image URL text, possibly empty.

    Returns:
        Tuple ``(annotated_file_path, status_message)``. The path is None on
        failure, and the message then describes the error.
    """
    youtube_link = (youtube_link or "").strip()
    image_url = (image_url or "").strip()

    if not youtube_link and not image_url and uploaded_file is None:
        return None, "Please provide a YouTube link, image URL, or upload a file."

    # Remote inputs are downloaded into a private directory that is removed on
    # every exit path, success or error. The user's uploaded file never lives
    # here, so it is never deleted.
    with tempfile.TemporaryDirectory() as download_dir:
        if youtube_link:
            try:
                input_path = _download_youtube_video(youtube_link, download_dir)
            except Exception as e:
                return None, f"Error downloading video: {e}"
            if input_path is None:
                return None, "No suitable mp4 stream found."
        elif image_url:
            try:
                input_path = _download_image(image_url, download_dir)
            except Exception as e:
                return None, f"Error downloading image: {e}"
        else:
            input_path = getattr(uploaded_file, "name", uploaded_file)

        # save=True makes ultralytics write annotated results to disk as well.
        try:
            results = model.predict(source=input_path, save=True)
        except Exception as e:
            return None, f"Error running prediction: {e}"

        try:
            output_path = _annotated_output_path(results)
        except Exception as e:
            return None, f"Error processing the file: {e}"

    return output_path, "Success!"

# Gradio UI: a header, a static showcase image, the three input sources, and
# the output widgets, all wired to process_input. Components render in the
# order they are created inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown("# Pose Detection with YOLO11-pose")
    # Read-only example image (crowdresult.jpg) loaded from the working directory.
    gr.Image(label="Crowd Result", value="crowdresult.jpg", interactive=False)
    gr.Markdown("Upload an image/video, provide an image URL, or supply a YouTube link to detect human poses.")

    # Input sources, listed below in the same order process_input takes them.
    with gr.Row():
        file_input = gr.File(label="Upload Image/Video")
    with gr.Row():
        youtube_input = gr.Textbox(placeholder="https://...", label="YouTube Link")
        image_url_input = gr.Textbox(placeholder="https://...", label="Image URL")

    # Outputs: the annotated file for download plus a status line.
    output_file = gr.File(label="Download Annotated Output")
    output_text = gr.Textbox(interactive=False, label="Status")
    run_button = gr.Button("Run Pose Detection")

    run_button.click(
        fn=process_input,
        inputs=[file_input, youtube_input, image_url_input],
        outputs=[output_file, output_text],
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()