File size: 4,296 Bytes
5260c34
7c90da7
 
5260c34
 
7afb01c
7c90da7
 
7afb01c
5260c34
 
7afb01c
7c90da7
 
7afb01c
7c90da7
7afb01c
bbc9616
 
7c90da7
5260c34
 
7afb01c
7c90da7
 
 
 
 
 
 
bbc9616
7c90da7
 
bbc9616
7afb01c
 
 
 
 
bbc9616
7afb01c
 
 
 
 
bbc9616
7afb01c
7c90da7
 
 
bbc9616
5260c34
bbc9616
7c90da7
5260c34
7c90da7
bbc9616
5260c34
 
 
7afb01c
5260c34
 
 
7afb01c
 
5260c34
 
 
bbc9616
5260c34
bbc9616
7afb01c
7c90da7
bbc9616
 
 
7c90da7
7afb01c
5260c34
7c90da7
7afb01c
 
bbc9616
7c90da7
 
7afb01c
 
 
bbc9616
 
7c90da7
bbc9616
5260c34
7c90da7
bbc9616
 
 
 
 
 
7c90da7
7afb01c
7c90da7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import sys
import gradio as gr
import os
import tempfile
import cv2
import requests
from ultralytics import YOLO

# Hugging Face Spaces may append a stray "--import" flag to the command line;
# drop every occurrence so it is not seen by later argument handling.
sys.argv = list(filter(lambda cli_arg: cli_arg != "--import", sys.argv))

# Instantiate the YOLO11 pose model from its weights file (per the original
# author's note, the weights are fetched automatically when not present).
model = YOLO("yolo11n-pose.pt")

# Upper bound (seconds) for the remote image fetch so a stalled server cannot
# block a worker forever.
_DOWNLOAD_TIMEOUT_SECONDS = 30


def _download_youtube(link, workdir):
    """Download the highest-resolution progressive mp4 of *link* into *workdir*.

    Returns ``(path, None)`` on success or ``(None, error_message)`` on failure.
    """
    try:
        # Imported lazily so the app still starts when pytube is not installed.
        from pytube import YouTube

        stream = (
            YouTube(link)
            .streams.filter(file_extension="mp4", progressive=True)
            .order_by("resolution")
            .desc()
            .first()
        )
        if stream is None:
            return None, "No suitable mp4 stream found."
        return stream.download(output_path=workdir), None
    except Exception as e:
        return None, f"Error downloading video: {e}"


def _download_image(url, workdir):
    """Fetch the image at *url* and store it inside *workdir*.

    Returns ``(path, None)`` on success or ``(None, error_message)`` on failure.
    """
    # NOTE(review): *url* is user-supplied and fetched server-side; consider
    # restricting the allowed hosts if this app is exposed publicly.
    try:
        with requests.get(url, timeout=_DOWNLOAD_TIMEOUT_SECONDS) as response:
            if response.status_code != 200:
                return None, f"Error downloading image: HTTP {response.status_code}"
            payload = response.content
        image_path = os.path.join(workdir, "downloaded_image.jpg")
        with open(image_path, "wb") as f:
            f.write(payload)
        return image_path, None
    except Exception as e:
        return None, f"Error downloading image: {e}"


def _run_pose_detection(input_path):
    """Run the pose model on *input_path* and locate or create the annotated output.

    Returns the ``(download_path, display_path, status_message)`` tuple that
    ``process_input`` hands back to the UI.
    """
    # save=True so annotated outputs are also written to disk by the model.
    try:
        results = model.predict(source=input_path, save=True)
    except Exception as e:
        return None, None, f"Error running prediction: {e}"

    try:
        first = results[0]
        if hasattr(first, "save_path"):
            # Use the path reported by the result object when it exposes one.
            output_path = first.save_path
        else:
            # Otherwise render the annotated frame ourselves. A unique file
            # name keeps concurrent requests from overwriting each other.
            annotated = first.plot()  # returns a numpy array
            fd, output_path = tempfile.mkstemp(prefix="annotated_", suffix=".jpg")
            os.close(fd)
            if not cv2.imwrite(output_path, annotated):
                raise OSError(f"could not write annotated image to {output_path}")
    except Exception as e:
        return None, None, f"Error processing the file: {e}"

    # The same output path is used for both download and display.
    return output_path, output_path, "Success!"


def process_input(uploaded_file, youtube_link, image_url):
    """
    Process an uploaded file, a YouTube link, or an image URL for pose detection.

    Priority: YouTube link > Image URL > Uploaded file. Blank or
    whitespace-only links are treated as not provided.

    Args:
        uploaded_file: Value from the file-upload component; either an object
            with a ``name`` attribute holding the path, or the path itself.
        youtube_link: YouTube URL, or ``None``/empty.
        image_url: Direct image URL, or ``None``/empty.

    Returns:
        A tuple ``(download_file_path, display_file_path, status_message)``.
        Both paths are ``None`` when an error occurred; the message then
        describes the failure.
    """
    youtube_link = (youtube_link or "").strip()
    image_url = (image_url or "").strip()

    # Scratch directory that holds a downloaded input. It stays None for
    # uploads so that a user's uploaded file is never deleted by cleanup.
    workdir = None
    try:
        if youtube_link:
            workdir = tempfile.mkdtemp(prefix="pose_input_")
            input_path, error = _download_youtube(youtube_link, workdir)
        elif image_url:
            workdir = tempfile.mkdtemp(prefix="pose_input_")
            input_path, error = _download_image(image_url, workdir)
        elif uploaded_file is not None:
            # Accept both file-like wrappers (with .name) and plain path strings.
            input_path, error = getattr(uploaded_file, "name", uploaded_file), None
        else:
            return None, None, "Please provide a YouTube link, image URL, or upload a file."

        if error is not None:
            return None, None, error

        return _run_pose_detection(input_path)
    finally:
        # Remove the downloaded input on every exit path, including errors.
        if workdir is not None:
            import shutil

            shutil.rmtree(workdir, ignore_errors=True)

# Define the Gradio interface.
# Components are created inside the Blocks context in the order written below.
# The click handler at the end passes the three input components to
# process_input and maps its 3-tuple return value onto the three outputs.
with gr.Blocks() as demo:
    gr.Markdown("# Pose Detection with YOLO11-pose")
    # Static, non-interactive example image; "crowdresult.jpg" is a relative
    # path (presumably shipped next to this script — TODO confirm).
    gr.Image(value="crowdresult.jpg", label="Crowd Result", interactive=False)
    gr.Markdown("Upload an image/video, provide an image URL, or supply a YouTube link to detect human poses.")
    
    # First row: file upload. Second row: the two URL text inputs side by side.
    with gr.Row():
        file_input = gr.File(label="Upload Image/Video")
    with gr.Row():
        youtube_input = gr.Textbox(label="YouTube Link", placeholder="https://...")
        image_url_input = gr.Textbox(label="Image URL", placeholder="https://...")
    
    # Three outputs: one for file download, one for immediate display, and one for status text.
    output_file = gr.File(label="Download Annotated Output")
    output_display = gr.Image(label="Annotated Output")
    output_text = gr.Textbox(label="Status", interactive=False)
    run_button = gr.Button("Run Pose Detection")
    
    # Input order must match process_input(uploaded_file, youtube_link, image_url);
    # output order must match its (download_path, display_path, status) return tuple.
    run_button.click(
        process_input,
        inputs=[file_input, youtube_input, image_url_input],
        outputs=[output_file, output_display, output_text]
    )

# Only launch the interface if executed directly; importing this module
# builds `demo` but does not start the server.
if __name__ == "__main__":
    demo.launch()