File size: 5,236 Bytes
152a369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b41cec1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152a369
ce8aef9
152a369
 
b41cec1
 
 
152a369
 
 
b41cec1
 
 
 
 
 
 
 
 
 
 
 
 
 
152a369
b41cec1
 
 
 
 
152a369
b41cec1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152a369
 
b41cec1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import gradio as gr
import pixeltable as pxt
from pixeltable.functions.huggingface import clip_image, clip_text
from pixeltable.iterators import FrameIterator
import PIL.Image
import os

# Embedding functions
@pxt.expr_udf
def embed_image(img: PIL.Image.Image):
    """Build the CLIP image-embedding expression for ``img``.

    Delegates to ``clip_image`` with the ``openai/clip-vit-base-patch32``
    checkpoint.
    """
    clip_checkpoint = 'openai/clip-vit-base-patch32'
    image_embedding = clip_image(img, model_id=clip_checkpoint)
    return image_embedding

@pxt.expr_udf
def str_embed(s: str):
    """Build the CLIP text-embedding expression for the string ``s``.

    Delegates to ``clip_text`` with the ``openai/clip-vit-base-patch32``
    checkpoint.
    """
    clip_checkpoint = 'openai/clip-vit-base-patch32'
    text_embedding = clip_text(s, model_id=clip_checkpoint)
    return text_embedding

# Process video and create index
def process_video(video_file, progress=gr.Progress()):
    """Load an uploaded video into Pixeltable and index its frames for search.

    The ``video_search`` directory is dropped and recreated on every call, so
    any previously processed video and its index are discarded.

    Args:
        video_file: Value delivered by the upload component. The path is read
            from its ``name`` attribute; a plain path string is accepted too.
        progress: Gradio progress tracker used to report each stage.

    Returns:
        A status message to show in the UI.

    Raises:
        gr.Error: If no video has been uploaded.
    """
    # Fail with a readable message instead of an AttributeError on None.
    if video_file is None:
        raise gr.Error("Please upload a video before processing.")

    # NOTE(review): the upload value is assumed to be either an object with a
    # ``name`` path attribute or a plain path string -- confirm against the
    # Gradio version in use.
    video_path = getattr(video_file, 'name', video_file)

    progress(0, desc="Initializing...")

    # Pixeltable setup: start from a clean directory on every run.
    pxt.drop_dir('video_search', force=True)
    pxt.create_dir('video_search')

    video_table = pxt.create_table('video_search.videos', {'video': pxt.VideoType()})

    # View over the video table, populated by FrameIterator with fps=1.
    frames_view = pxt.create_view(
        'video_search.frames',
        video_table,
        iterator=FrameIterator.create(video=video_table.video, fps=1)
    )

    progress(0.2, desc="Inserting video...")
    video_table.insert([{'video': video_path}])

    progress(0.4, desc="Creating embedding index...")
    frames_view.add_embedding_index('frame', string_embed=str_embed, image_embed=embed_image)

    progress(1.0, desc="Processing complete")
    return "Video processed and indexed successfully!"

# Perform similarity search
def similarity_search(query, search_type, num_results, progress=gr.Progress()):
    """Return the indexed frames most similar to ``query``.

    Args:
        query: Text string or image to search with.
        search_type: Search mode selected in the UI ("Text" or "Image"). Kept
            for interface compatibility; the same ``similarity`` call is used
            for both modes.
        num_results: Maximum number of frames to return; coerced to ``int``.
        progress: Gradio progress tracker.

    Returns:
        A list with the ``frame`` value of each result row, ordered by
        descending similarity.

    Raises:
        gr.Error: If the query is missing or blank.
    """
    # Reject a missing image or blank text before touching the index.
    if query is None or (isinstance(query, str) and not query.strip()):
        raise gr.Error("Please provide a query before searching.")

    # NOTE(review): presumably this raises if no video has been processed yet
    # (the view would not exist) -- confirm and surface a friendlier message.
    frames_view = pxt.get_table('video_search.frames')

    progress(0.5, desc="Performing search...")
    # Text and image queries go through the same call on the indexed column.
    sim = frames_view.frame.similarity(query)

    # The slider value may arrive as a float; limit() is given an int.
    top_k = int(num_results)
    results = frames_view.order_by(sim, asc=False).limit(top_k).select(frames_view.frame, sim=sim).collect()

    progress(1.0, desc="Search complete")

    return [row['frame'] for row in results]

# Gradio interface
with gr.Blocks(theme=gr.themes.Base()) as demo:
    # Page header: logo and title rendered as inline HTML.
    gr.Markdown(
        """
        <div style="margin-bottom: 20px;">
            <img src="https://raw.githubusercontent.com/pixeltable/pixeltable/main/docs/source/data/pixeltable-logo-large.png" alt="Pixeltable" style="max-width: 150px;" />
            <h2>Video Frame Search with Pixeltable</h2>
        </div>
        """
    )
    
    with gr.Row():
        # Left column: upload/processing controls, then the search controls.
        with gr.Column(scale=1):
            video_file = gr.File(label="Upload Video")
            process_button = gr.Button("Process Video")
            process_output = gr.Textbox(label="Status", lines=2)
            
            gr.Markdown("---")
            
            search_type = gr.Radio(["Text", "Image"], label="Search Type", value="Text")
            text_input = gr.Textbox(label="Text Query")
            # Hidden initially because the default search type is "Text".
            image_input = gr.Image(label="Image Query", type="pil", visible=False)
            num_results = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of Results")
            search_button = gr.Button("Search")
        
        # Right column: gallery that receives the matching frames.
        with gr.Column(scale=2):
            results_gallery = gr.Gallery(label="Search Results", columns=3)
    
    def update_search_input(choice):
        """Show only the query input that matches the selected search type."""
        return gr.update(visible=choice=="Text"), gr.update(visible=choice=="Image")

    search_type.change(update_search_input, search_type, [text_input, image_input])
    
    process_button.click(
        process_video,
        inputs=[video_file],
        outputs=[process_output]
    )
    
    def perform_search(search_type, text_query, image_query, num_results):
        """Pick the query for the active search type and run the search."""
        query = text_query if search_type == "Text" else image_query
        return similarity_search(query, search_type, num_results)

    search_button.click(
        perform_search,
        inputs=[search_type, text_input, image_input, num_results],
        outputs=[results_gallery]
    )

# Start the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()