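"""Gradio demo that runs Sapiens lite models (seg, pose, depth, normal) on images and videos."""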
import os
import gradio as gr
import numpy as np
from PIL import Image
import cv2
import spaces

from inference.seg import process_image_or_video
from config import SAPIENS_LITE_MODELS_PATH

def update_model_choices(task):
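    """Return an updated dropdown listing the model versions available for the selected task."""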
    model_choices = list(SAPIENS_LITE_MODELS_PATH[task.lower()].keys())
    return gr.Dropdown(choices=model_choices, value=model_choices[0] if model_choices else None)

@spaces.GPU(duration=120)
def process_image(input_image, task, version):
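    """Run the selected task on a single image; accepts a NumPy array or a PIL image."""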
    if isinstance(input_image, np.ndarray):
        input_image = Image.fromarray(input_image)
    
    result = process_image_or_video(input_image, task=task.lower(), version=version)
    
    return result

# Assumption: like process_image above, the video path likely needs its own GPU allocation
# when running on ZeroGPU Spaces, since GPU access is only granted inside decorated functions.
@spaces.GPU(duration=120)
def process_video(input_video, task, version):
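    """Process a video frame by frame and write the result to an MP4 file."""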
    cap = cv2.VideoCapture(input_video)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    
    # Write processed frames to a new MP4 at the source resolution and frame rate.
    output_video = cv2.VideoWriter('output_video.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        
        # OpenCV decodes frames as BGR; convert to RGB before running the model.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        processed_frame = process_image_or_video(frame_rgb, task=task.lower(), version=version)
        
        # Skip frames the model could not process; convert the rest back to BGR for writing.
        if processed_frame is not None:
            processed_frame_bgr = cv2.cvtColor(np.array(processed_frame), cv2.COLOR_RGB2BGR)
            output_video.write(processed_frame_bgr)
    
    cap.release()
    output_video.release()
    
    return 'output_video.mp4'

# UI: one tab for images and one for videos, each with a task selector and a model dropdown.
with gr.Blocks() as demo:
    gr.Markdown("# Sapiens Arena 🤸🏽‍♂️ - WIP devmode")
    with gr.Tabs():
        with gr.TabItem('Image'):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil")
                    select_task_image = gr.Radio(
                        ["seg", "pose", "depth", "normal"], 
                        label="Task", 
                        info="Choose the task to perform",
                        value="seg"
                    )
                    model_name_image = gr.Dropdown(
                        label="Model Version",
                        choices=list(SAPIENS_LITE_MODELS_PATH["seg"].keys()),
                        value="sapiens_0.3b",
                    )
                with gr.Column():
                    result_image = gr.Image(label="Result")
                    run_button_image = gr.Button("Run")
        
        with gr.TabItem('Video'):
            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(label="Input Video")
                    select_task_video = gr.Radio(
                        ["seg", "pose", "depth", "normal"], 
                        label="Task", 
                        info="Choose the task to perform",
                        value="seg"
                    )
                    model_name_video = gr.Dropdown(
                        label="Model Version",
                        choices=list(SAPIENS_LITE_MODELS_PATH["seg"].keys()),
                        value="sapiens_0.3b",
                    )
                with gr.Column():
                    result_video = gr.Video(label="Result")
                    run_button_video = gr.Button("Run")

    # Keep each model dropdown in sync with its selected task.
    select_task_image.change(fn=update_model_choices, inputs=select_task_image, outputs=model_name_image)
    select_task_video.change(fn=update_model_choices, inputs=select_task_video, outputs=model_name_video)

    run_button_image.click(
        fn=process_image,
        inputs=[input_image, select_task_image, model_name_image],
        outputs=[result_image],
    )

    run_button_video.click(
        fn=process_video,
        inputs=[input_video, select_task_video, model_name_video],
        outputs=[result_video],
    )

if __name__ == "__main__":
    demo.launch(share=False)