File size: 7,750 Bytes
b7c7aa0
 
dc465b0
b7c7aa0
ee6d0d7
b0844d7
 
b7c7aa0
ee6d0d7
31c66a7
dc465b0
ee6d0d7
98d1f42
ee6d0d7
 
b7c7aa0
ee6d0d7
98d1f42
ee6d0d7
 
 
 
 
 
 
b7c7aa0
ee6d0d7
 
 
 
 
 
 
98d1f42
ee6d0d7
 
d250946
ee6d0d7
 
210435d
98d1f42
210435d
 
6b6240e
210435d
 
 
 
 
 
 
ee6d0d7
 
 
 
 
 
 
 
 
 
 
b7c7aa0
 
 
 
 
 
 
 
 
 
 
dc465b0
b7c7aa0
 
dc465b0
 
 
 
 
 
50e7d4a
 
 
 
 
 
 
 
 
 
b7c7aa0
 
 
 
 
 
 
 
50e7d4a
 
1416e63
 
50e7d4a
 
 
 
 
 
 
 
b7c7aa0
 
 
50e7d4a
 
1416e63
 
50e7d4a
1416e63
 
50e7d4a
 
1416e63
b7c7aa0
 
dc465b0
ee6d0d7
b7c7aa0
 
 
 
 
50e7d4a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# app.py
import os
import tempfile

import gradio as gr
import requests
from datasets import load_dataset

from constraint import SYS_PROMPT, USER_PROMPT
from utils import VideoProcessor, AzureAPI, GoogleAPI, AnthropicAPI, OpenAIAPI

def load_hf_dataset(dataset_path, auth_token):
    """Load a dataset through ``load_dataset`` and return it unchanged.

    Args:
        dataset_path: Identifier handed to ``load_dataset`` as its first argument.
        auth_token: Value forwarded to ``load_dataset`` as ``token``.

    Returns:
        Whatever ``load_dataset`` returns; callers in this file index it
        by split name (``["train"]``).
    """
    loaded = load_dataset(dataset_path, token=auth_token)
    # Progress marker on stdout once loading has returned.
    print("load done")
    return loaded

def _remove_temp_file(path):
    """Delete *path*, reporting (not raising) if the OS refuses."""
    try:
        os.remove(path)
    except OSError as err:
        # Best-effort cleanup: a leftover temp file must not hide the
        # caption result or the original error.
        print(f"could not remove temporary file {path}: {err}")


def _download_video(url, suffix=".mp4", timeout=60, chunk_size=1024 * 1024):
    """Stream *url* into a named temporary file and return the file's path.

    The caller owns the returned file and is responsible for deleting it.

    Raises:
        Exception: if the server answers with a status other than 200.
    """
    # The context manager closes the connection; the timeout keeps a stalled
    # server from hanging the request forever.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise Exception(f"Failed to download video, status code: {response.status_code}")

        temp_video_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        try:
            with temp_video_file:
                # Write chunk by chunk so the whole video is never held in memory.
                for chunk in response.iter_content(chunk_size=chunk_size):
                    temp_video_file.write(chunk)
        except Exception:
            # Do not leave a partially written file behind.
            _remove_temp_file(temp_video_file.name)
            raise
        return temp_video_file.name


def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
    """Caption an uploaded video or every video of a Hugging Face dataset.

    Source selection, in priority order:

    1. ``video_src`` (uploaded file): frames are decoded with
       ``VideoProcessor`` and captioned through ``AzureAPI``.
    2. ``video_hf`` together with ``video_hf_auth``: the dataset's ``"train"``
       split is loaded, each row's ``"id"`` value is downloaded as a video
       URL, and every video is captioned the same way.
    3. Anything else: nothing is processed.

    ``video_od``, ``video_od_auth``, ``video_gd`` and ``video_gd_auth`` are
    accepted but not used yet.

    Returns:
        A ``(caption_text, info_message, debug_image)`` tuple. ``debug_image``
        is the concatenated frames for an upload and ``None`` otherwise.

    Raises:
        Exception: if a dataset video cannot be downloaded (non-200 status).
    """
    print("begin caption")

    if video_src:
        processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
        frames = processor._decode(video_src)
        base64_list = processor.to_base64_list(frames)
        debug_image = processor.concatenate(frames)

        # Frames are still returned for inspection when no API call is possible.
        if not key or not endpoint:
            return "", f"API key or endpoint is missing. Processed {len(frames)} frames.", debug_image

        api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
        caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
        return f"{caption}", f"Using model '{model}' with {len(frames)} frames extracted.", debug_image

    if video_hf and video_hf_auth:
        print("begin video_hf")
        # Same credential check as the upload branch, done before any
        # dataset or video download is attempted.
        if not key or not endpoint:
            return "", "API key or endpoint is missing.", None

        rows = load_hf_dataset(video_hf, video_hf_auth)["train"]
        all_captions = []
        for row in rows:
            print("video_path")
            # Each row's "id" field is used as the video's download URL.
            video_path = _download_video(row["id"])
            try:
                processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
                frames = processor._decode(video_path)
                base64_list = processor.to_base64_list(frames)
                api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
                all_captions.append(api.get_caption(sys_prompt, usr_prompt, base64_list))
            finally:
                # The temp file is only needed while this video is processed.
                _remove_temp_file(video_path)
        return "\n".join(all_captions), f"Processed {len(all_captions)} videos.", None

    # ... (Handle other sources)
    return "", "No video source selected.", None

# Gradio UI definition. Left column: debug output, sampling/frame settings,
# prompts and per-provider result boxes. Right column: provider credentials,
# data-source inputs and the "Caption" button.
with gr.Blocks() as Core:
    with gr.Row(variant="panel"):
        with gr.Column(scale=6):
            # Receives the info message and debug image returned by fast_caption.
            with gr.Accordion("Debug", open=False):
                info = gr.Textbox(label="Info", interactive=False)
                frame = gr.Image(label="Frame", interactive=False)
            with gr.Accordion("Configuration", open=False):
                with gr.Row():
                    temp = gr.Slider(0, 1, 0.3, step=0.1, label="Temperature")
                    top_p = gr.Slider(0, 1, 0.75, step=0.1, label="Top-P")
                    max_tokens = gr.Slider(512, 4096, 1024, step=1, label="Max Tokens")
                with gr.Row():
                    frame_format = gr.Dropdown(label="Frame Format", value="JPEG", choices=["JPEG", "PNG"], interactive=False)
                    frame_limit = gr.Slider(1, 100, 10, step=1, label="Frame Limits")
            with gr.Tabs():
                with gr.Tab("User"):
                    usr_prompt = gr.Textbox(USER_PROMPT, label="User Prompt", lines=10, max_lines=100, show_copy_button=True)
                with gr.Tab("System"):
                    sys_prompt = gr.Textbox(SYS_PROMPT, label="System Prompt", lines=10, max_lines=100, show_copy_button=True)
            # Only the Azure result box is wired to the button below.
            with gr.Tabs():
                with gr.Tab("Azure"):
                    result = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
                with gr.Tab("Google"):
                    result_gg = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
                with gr.Tab("Anthropic"):
                    result_ac = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
                with gr.Tab("OpenAI"):
                    result_oai = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)

        with gr.Column(scale=2):
            with gr.Column():
                # Only the Azure model/key/endpoint are passed to fast_caption.
                with gr.Accordion("Model Provider", open=True):
                    with gr.Tabs():
                        with gr.Tab("Azure"):
                            model = gr.Dropdown(label="Model", value="GPT-4o", choices=["GPT-4o", "GPT-4v"], interactive=False)
                            key = gr.Textbox(label="Azure API Key")
                            endpoint = gr.Textbox(label="Azure Endpoint")
                        with gr.Tab("Google"):
                            model_gg = gr.Dropdown(label="Model", value="Gemini-1.5-Flash", choices=["Gemini-1.5-Flash", "Gemini-1.5-Pro"], interactive=False)
                            key_gg = gr.Textbox(label="Gemini API Key")
                            endpoint_gg = gr.Textbox(label="Gemini API Endpoint")
                        with gr.Tab("Anthropic"):
                            model_ac = gr.Dropdown(label="Model", value="Claude-3-Opus", choices=["Claude-3-Opus", "Claude-3-Sonnet"], interactive=False)
                            key_ac = gr.Textbox(label="Anthropic API Key")
                            endpoint_ac = gr.Textbox(label="Anthropic Endpoint")
                        with gr.Tab("OpenAI"):
                            model_oai = gr.Dropdown(label="Model", value="GPT-4o", choices=["GPT-4o", "GPT-4v"], interactive=False)
                            key_oai = gr.Textbox(label="OpenAI API Key")
                            endpoint_oai = gr.Textbox(label="OpenAI Endpoint")
                with gr.Accordion("Data Source", open=True):
                    with gr.Tabs():
                        with gr.Tab("Upload"):
                            video_src = gr.Video(sources="upload", show_label=False, show_share_button=False, mirror_webcam=False)
                        with gr.Tab("HF"):
                            video_hf = gr.Text(label="Huggingface File Path")
                            video_hf_auth = gr.Text(label="Huggingface Token")
                        with gr.Tab("Onedrive"):
                            # The first positional argument of gr.Text is the
                            # initial value, so the caption is passed as label=
                            # to keep the field empty.
                            video_od = gr.Text(label="Microsoft Onedrive")
                            video_od_auth = gr.Text(label="Microsoft Onedrive Token")
                        with gr.Tab("Google Drive"):
                            video_gd = gr.Text(label="Google Drive")
                            video_gd_auth = gr.Text(label="Google Drive Access Token")
                caption_button = gr.Button("Caption", variant="primary", size="lg")
        # The input order must match fast_caption's positional parameters.
        caption_button.click(
            fast_caption,
            inputs=[sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit],
            outputs=[result, info, frame]
        )

# Start the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    Core.launch()