File size: 6,745 Bytes
5260c34
7c90da7
 
5260c34
 
7afb01c
7c90da7
 
7aa805e
5260c34
 
7aa805e
7c90da7
 
fc38e58
7c90da7
7aa805e
fc38e58
7aa805e
7c90da7
5260c34
7aa805e
5260c34
7afb01c
7c90da7
 
7aa805e
7c90da7
f34aee4
7aa805e
eadcbdc
7aa805e
 
 
 
7c90da7
7aa805e
 
eadcbdc
 
 
7aa805e
 
 
 
eadcbdc
7aa805e
 
eadcbdc
7aa805e
 
eadcbdc
 
 
 
7aa805e
eadcbdc
7aa805e
 
680b8e8
eadcbdc
 
7aa805e
 
 
680b8e8
 
 
7aa805e
680b8e8
 
 
7aa805e
 
 
 
 
 
30fd2ce
 
 
 
7aa805e
 
 
 
 
 
 
 
 
30fd2ce
680b8e8
7aa805e
 
 
 
 
 
680b8e8
7aa805e
 
 
 
 
 
 
 
eadcbdc
7aa805e
 
eadcbdc
7aa805e
 
 
 
 
 
 
 
eadcbdc
7aa805e
eadcbdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7aa805e
 
eadcbdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c90da7
7aa805e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import sys
import gradio as gr
import os
import tempfile
import cv2
import requests
from ultralytics import YOLO

# Drop any literal "--import" argument, which the Spaces launcher may append
# to the command line; every other argument is kept in its original order.
sys.argv = list(filter(lambda arg: arg != "--import", sys.argv))

# YOLO11 pose-estimation weights. The model is loaded once at import time and
# the same instance is shared by every request handled below.
_POSE_WEIGHTS = "yolo11n-pose.pt"
model = YOLO(_POSE_WEIGHTS)

# Container extensions routed to the frame-by-frame video pipeline; any other
# extension is handed to the model as a still image.
_VIDEO_EXTS = (".mp4", ".mov", ".avi", ".webm", ".mkv", ".m4v")

# Frame rate used when the source container reports none (cv2 yields 0 or NaN).
_FALLBACK_FPS = 30.0


def _new_temp_path(prefix, suffix):
    """Return a fresh, random file path ``<tmpdir>/<prefix>_<hex><suffix>``."""
    return os.path.join(tempfile.gettempdir(), f"{prefix}_{os.urandom(8).hex()}{suffix}")


def _upload_path(uploaded_file):
    """Return the filesystem path behind a ``gr.File`` value.

    Depending on the Gradio version/configuration the value is either a plain
    path string or a tempfile-like object exposing ``.name``; accept both.
    """
    if isinstance(uploaded_file, (str, os.PathLike)):
        return os.fspath(uploaded_file)
    return uploaded_file.name


def _safe_fps(fps):
    """Return *fps* as a float if it is finite and positive, else the fallback."""
    import math

    if isinstance(fps, (int, float)) and math.isfinite(fps) and fps > 0:
        return float(fps)
    return _FALLBACK_FPS


def _remove_quietly(paths):
    """Best-effort delete of temporary files; ``None`` and missing paths are skipped.

    Only ``OSError`` is swallowed: a failed cleanup must never mask the result
    being returned to the user.
    """
    for path in paths:
        if path and os.path.exists(path):
            try:
                os.remove(path)
            except OSError:
                pass


def _download_youtube(youtube_link, dest_path):
    """Download the highest-resolution progressive MP4 stream to *dest_path*.

    Returns False (downloading nothing) when no progressive MP4 stream exists.
    """
    from pytubefix import YouTube  # Use pytubefix instead of pytube

    stream = (
        YouTube(youtube_link)
        .streams.filter(file_extension='mp4', progressive=True)
        .order_by("resolution")
        .desc()
        .first()
    )
    if not stream:
        return False
    stream.download(output_path=os.path.dirname(dest_path), filename=os.path.basename(dest_path))
    return True


def _download_image(image_url, dest_path):
    """Stream the resource at *image_url* into *dest_path* in chunks.

    Uses a 10 s connect/read timeout; raises ``requests`` exceptions on
    network errors or non-2xx responses.
    """
    with requests.get(image_url, stream=True, timeout=10) as response:
        response.raise_for_status()
        with open(dest_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=65536):
                f.write(chunk)


def _annotate_video(input_path, output_path, sensitivity):
    """Write a pose-annotated copy of the video at *input_path* to *output_path*.

    Returns None on success or a user-facing error message on failure. Capture
    and writer handles are released even if inference raises.
    """
    cap = cv2.VideoCapture(input_path)
    writer = None
    try:
        if not cap.isOpened():
            return "Error opening video file."
        fps = _safe_fps(cap.get(cv2.CAP_PROP_FPS))
        size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
        if not writer.isOpened():
            return "Error creating output video."
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # cv2 frames are BGR, the channel order Ultralytics expects for
            # ndarray sources; plot() draws onto that same BGR frame, so its
            # output goes to the writer without any colour conversion.
            results = model.predict(source=frame, conf=sensitivity, verbose=False)[0]
            writer.write(results.plot())
    finally:
        cap.release()
        if writer is not None:
            writer.release()

    if not os.path.exists(output_path) or os.path.getsize(output_path) == 0:
        return "Error: Output video is empty."
    return None


def _annotate_image(input_path, output_path, sensitivity):
    """Write a pose-annotated copy of the image at *input_path* to *output_path*.

    Returns None on success or a user-facing error message on failure.
    """
    results = model.predict(source=input_path, conf=sensitivity)[0]
    if not cv2.imwrite(output_path, results.plot()):
        return "Error: could not write annotated image."
    return None


def process_input(uploaded_file, youtube_link, image_url, sensitivity):
    """
    Process input from Upload, YouTube, or Image URL.
    Priority: YouTube link > Image URL > Uploaded file.
    Sensitivity is the confidence threshold.

    Returns a 4-tuple ``(download_path, image_path, video_path, status)`` that
    matches the Gradio outputs ``[output_file, output_image, output_video,
    output_text]``. On failure the three paths are None and ``status`` carries
    the error message. Downloaded inputs are always deleted; the annotated
    output is kept only on success (so the user can download it).
    """
    temp_inputs = []  # files this call downloaded; deleted unconditionally
    output_path = None
    succeeded = False
    try:
        youtube_link = (youtube_link or "").strip()
        image_url = (image_url or "").strip()

        # Priority 1: YouTube link
        if youtube_link:
            input_path = _new_temp_path("yt", ".mp4")
            temp_inputs.append(input_path)  # registered first so partial downloads are removed
            try:
                if not _download_youtube(youtube_link, input_path):
                    return None, None, None, "No suitable mp4 stream found."
            except Exception as e:
                return None, None, None, f"Error downloading YouTube video: {str(e)}"

        # Priority 2: Image URL
        elif image_url:
            input_path = _new_temp_path("img", ".jpg")
            temp_inputs.append(input_path)
            try:
                _download_image(image_url, input_path)
            except Exception as e:
                return None, None, None, f"Error downloading image: {str(e)}"

        # Priority 3: Uploaded file (owned by Gradio, so never deleted here)
        elif uploaded_file is not None:
            input_path = _upload_path(uploaded_file)
        else:
            return None, None, None, "Please provide an input."

        ext = os.path.splitext(input_path)[1].lower()
        if ext in _VIDEO_EXTS:
            output_path = _new_temp_path("out", ".mp4")
            error = _annotate_video(input_path, output_path, sensitivity)
            if error:
                return None, None, None, error
            succeeded = True
            return output_path, None, output_path, "Video processed successfully!"

        output_path = _new_temp_path("out", ".jpg")
        error = _annotate_image(input_path, output_path, sensitivity)
        if error:
            return None, None, None, error
        succeeded = True
        return output_path, output_path, None, "Image processed successfully!"

    except Exception as e:
        # UI boundary: report unexpected failures in the status box rather
        # than letting them propagate to Gradio.
        return None, None, None, f"Processing error: {str(e)}"

    finally:
        _remove_quietly(temp_inputs)
        if not succeeded:
            # Partial or empty outputs are useless to the user; drop them too.
            _remove_quietly([output_path])

# --- Gradio UI ---------------------------------------------------------------
# Stylesheet for the annotated-image panel: stretch the rendered <img> to the
# panel width while preserving its aspect ratio.
_RESULT_IMG_CSS = """
.result_img > img {
  width: 100%;
  height: auto;
  object-fit: contain;
}
"""

with gr.Blocks(css=_RESULT_IMG_CSS) as demo:
    with gr.Row():
        # Left column: banner, title, the three input tabs and the threshold.
        with gr.Column(scale=1):
            gr.HTML("<div style='text-align:center;'><img src='https://huggingface.co/spaces/tstone87/stance-detection/resolve/main/crowdresult.jpg' style='width:25%;'/></div>")
            gr.Markdown("## Pose Detection with YOLO11-pose")
            with gr.Tabs():
                with gr.TabItem("Upload File"):
                    file_input = gr.File(label="Upload Image/Video")
                with gr.TabItem("YouTube Link"):
                    youtube_input = gr.Textbox(label="YouTube Link", placeholder="https://...")
                with gr.TabItem("Image URL"):
                    image_url_input = gr.Textbox(label="Image URL", placeholder="https://...")
            sensitivity_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.2,
                label="Sensitivity (Confidence Threshold)",
            )
        # Right column: result viewers plus a read-only status line.
        with gr.Column(scale=2):
            output_image = gr.Image(label="Annotated Output (Image)", elem_classes="result_img")
            output_video = gr.Video(label="Annotated Output (Video)")
            output_file = gr.File(label="Download Annotated Output")
            output_text = gr.Textbox(label="Status", interactive=False)

    # Order must match process_input's return tuple (file, image, video, status).
    result_components = [output_file, output_image, output_video, output_text]

    # Each tab calls process_input with only its own source filled in; the
    # other two slots receive empty gr.State placeholders.
    # NOTE(review): Textbox.change fires on every edit, so typing a URL one
    # character at a time starts one request per keystroke; a .submit trigger
    # or an explicit button may be preferable.
    trigger_bindings = (
        (file_input, [file_input, gr.State(""), gr.State(""), sensitivity_slider]),
        (youtube_input, [gr.State(None), youtube_input, gr.State(""), sensitivity_slider]),
        (image_url_input, [gr.State(None), gr.State(""), image_url_input, sensitivity_slider]),
    )
    for trigger, trigger_inputs in trigger_bindings:
        trigger.change(
            fn=process_input,
            inputs=trigger_inputs,
            outputs=result_components,
        )

if __name__ == "__main__":
    demo.launch()