File size: 3,419 Bytes
5260c34
7c90da7
 
5260c34
 
7c90da7
 
5260c34
 
 
 
7c90da7
 
 
 
 
5260c34
 
7c90da7
5260c34
 
 
 
7c90da7
 
 
 
5260c34
7c90da7
 
 
5260c34
7c90da7
 
5260c34
7c90da7
 
 
5260c34
 
 
7c90da7
5260c34
7c90da7
5260c34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c90da7
 
5260c34
7c90da7
5260c34
 
7c90da7
 
5260c34
7c90da7
 
 
5260c34
7c90da7
5260c34
7c90da7
5260c34
 
 
7c90da7
5260c34
7c90da7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import sys
import gradio as gr
import os
import tempfile
import cv2
from ultralytics import YOLO

# Hugging Face Spaces may append a stray "--import" flag to the command line;
# filter it out while keeping every other argument in its original order.
sys.argv = list(filter(lambda token: token != "--import", sys.argv))

# Module-level YOLO11-pose model shared by every request. Ultralytics
# auto-downloads the weights file if it is not already present.
model = YOLO("yolo11n-pose.pt")

def process_input(uploaded_file, youtube_link):
    """
    Process an uploaded file or a YouTube link to perform pose detection.

    Args:
        uploaded_file: Value of the ``gr.File`` input, or ``None`` when nothing
            was uploaded. Depending on the Gradio version/``type`` setting this
            may be a tempfile-like object exposing ``.name`` or a plain path
            string; both are accepted.
        youtube_link: Text of the YouTube-link textbox (may be ``None``, empty
            or whitespace-only). A non-blank link takes precedence over the
            uploaded file.

    Returns:
        A tuple ``(annotated_file_path, status_message)``. On success the
        message is ``"Success!"``. If an error occurs, ``annotated_file_path``
        is ``None`` and ``status_message`` describes the error.
    """
    link = youtube_link.strip() if youtube_link else ""

    if link:
        try:
            from pytube import YouTube
            yt = YouTube(link)
            # Get the highest resolution progressive mp4 stream.
            stream = (
                yt.streams.filter(file_extension="mp4", progressive=True)
                .order_by("resolution")
                .desc()
                .first()
            )
            if stream is None:
                return None, "No suitable mp4 stream found."
            input_path = stream.download()
        except Exception as e:
            return None, f"Error downloading video: {e}"

        # This function created the file, so it must delete it on EVERY exit
        # path (prediction/processing errors included), not only on success.
        try:
            return _annotate_pose(input_path)
        finally:
            # Best-effort cleanup: a failure to delete must not mask the
            # detection result or its error message.
            try:
                os.remove(input_path)
            except OSError:
                pass

    if uploaded_file is None:
        return None, "Please provide an uploaded file or a YouTube link."

    # Uploaded files are owned by Gradio and are never deleted here.
    input_path = getattr(uploaded_file, "name", uploaded_file)
    return _annotate_pose(input_path)


def _annotate_pose(input_path):
    """
    Run pose detection on ``input_path`` and return ``(output_path, status)``.

    Follows the same return convention as ``process_input``: ``output_path``
    is ``None`` on failure and the status string describes the error.
    """
    # save=True makes ultralytics also write its own outputs to disk.
    try:
        results = model.predict(source=input_path, save=True)
    except Exception as e:
        return None, f"Error running prediction: {e}"

    if not results:
        # Guard explicitly instead of surfacing an opaque IndexError.
        return None, "Error processing the file: the model returned no results."

    first = results[0]
    try:
        # Some YOLO versions may offer a 'save_path' attribute.
        if hasattr(first, "save_path"):
            return first.save_path, "Success!"
        # Fallback: render annotations onto the first result only.
        # NOTE(review): for video input this yields a single annotated frame,
        # not an annotated video.
        return _write_annotated_image(first.plot()), "Success!"
    except Exception as e:
        return None, f"Error processing the file: {e}"


def _write_annotated_image(annotated):
    """Write the annotated image array to a fresh temporary .jpg; return its path."""
    # A unique file per call: a fixed name would let concurrent requests
    # overwrite each other's output.
    fd, output_path = tempfile.mkstemp(prefix="annotated_", suffix=".jpg")
    os.close(fd)
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(output_path, annotated):
        os.remove(output_path)
        raise OSError(f"could not write annotated image to {output_path}")
    return output_path

# Define the Gradio interface with two outputs: one for the file and one for status text.
# NOTE: Blocks lays components out in creation order, so the statement order
# inside this `with` block determines the on-screen layout.
with gr.Blocks() as demo:
    gr.Markdown("# Pose Detection with YOLO11-pose")
    gr.Markdown("Upload an image/video or provide a YouTube link to detect human poses.")
    
    # Two alternative inputs side by side; process_input gives a non-blank
    # YouTube link precedence over the uploaded file.
    with gr.Row():
        file_input = gr.File(label="Upload Image/Video")
        youtube_input = gr.Textbox(label="Or enter a YouTube link", placeholder="https://...")
    
    # Receive the (annotated_file_path, status_message) tuple returned by
    # process_input; the status box is display-only.
    output_file = gr.File(label="Download Annotated Output")
    output_text = gr.Textbox(label="Status", interactive=False)
    run_button = gr.Button("Run Pose Detection")
    
    # inputs map positionally to process_input(uploaded_file, youtube_link);
    # outputs map positionally to its two return values.
    run_button.click(process_input, inputs=[file_input, youtube_input],
                     outputs=[output_file, output_text])

# Only launch the app if the script is executed directly.
if __name__ == "__main__":
    demo.launch()