|
import gradio as gr |
|
import pixeltable as pxt |
|
from pixeltable.functions.huggingface import clip_image, clip_text |
|
from pixeltable.iterators import FrameIterator |
|
import os |
|
import logging |
|
|
|
|
|
# Configure logging once at import time; every record carries a timestamp,
# the logger name and the level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Pixeltable paths used by this app: one directory holding the table of
# uploaded videos and the view of frames derived from it.
PROJECT_DIR = 'video_search'
VIDEOS_TABLE = f'{PROJECT_DIR}.videos'
FRAMES_VIEW = f'{PROJECT_DIR}.frames'
|
|
|
|
|
def process_video(video_file, progress=gr.Progress()):
    """Rebuild the Pixeltable project around one video and index its frames.

    The project directory is dropped and recreated on every call, so only the
    most recently processed video is searchable afterwards.

    Args:
        video_file: The uploaded file as delivered by the Gradio ``File``
            component: either a path string or an object whose ``.name``
            attribute is the path. ``None`` means nothing was uploaded.
        progress: Gradio progress tracker used to report the current stage.

    Returns:
        A human-readable status string (success message, a prompt to upload a
        file, or an error description). Exceptions are logged and reported in
        the returned string rather than raised.
    """
    if video_file is None:
        return "Please upload a video file first."

    try:
        # Accept both a plain path and a temp-file wrapper exposing `.name`.
        video_path = video_file if isinstance(video_file, str) else video_file.name

        progress(0, desc="Initializing...")
        logger.info("Processing video: %s", video_path)

        # Start from a clean slate: remove any previous project, then recreate it.
        pxt.drop_dir(PROJECT_DIR, force=True)
        pxt.create_dir(PROJECT_DIR)

        video_table = pxt.create_table(VIDEOS_TABLE, {'video': pxt.Video})

        # View over the video table that yields frames sampled at 1 fps.
        frames_view = pxt.create_view(
            FRAMES_VIEW,
            video_table,
            iterator=FrameIterator.create(video=video_table.video, fps=1),
        )

        progress(0.2, desc="Inserting video...")
        video_table.insert([{'video': video_path}])

        progress(0.4, desc="Creating embedding index...")

        # The same CLIP checkpoint embeds both text and image queries so they
        # can be compared against the indexed frames.
        clip_model = 'openai/clip-vit-base-patch32'
        frames_view.add_embedding_index(
            'frame',
            string_embed=clip_text.using(model_id=clip_model),
            image_embed=clip_image.using(model_id=clip_model),
        )

        progress(1.0, desc="Processing complete")
        return "✅ Video processed successfully! You can now search for specific moments using text or images."

    except Exception as e:
        # Top-level UI boundary: keep the traceback in the log and surface a
        # short message to the user instead of crashing the handler.
        logger.exception("Error processing video: %s", e)
        return f"Error processing video: {str(e)}"
|
|
|
|
|
def similarity_search(query, search_type, num_results, progress=gr.Progress()):
    """Return the frames most similar to a text or image query.

    Args:
        query: The search query, either a text string or an image. ``None``
            and blank strings yield no results.
        search_type: The selected search mode. Accepted for interface
            compatibility; it is not used by the lookup itself.
        num_results: Maximum number of frames to return; coerced to ``int``
            before being passed to ``limit()``.
        progress: Gradio progress tracker used to report the current stage.

    Returns:
        A list with the ``frame`` value of each matching row, ordered from
        most to least similar. An empty list is returned when there is no
        query, no frames view, or when the search fails (the failure is
        logged).
    """
    try:
        # Explicit emptiness test: avoids relying on the truthiness of
        # arbitrary image objects and also rejects whitespace-only text.
        if query is None or (isinstance(query, str) and not query.strip()):
            return []

        frames_view = pxt.get_table(FRAMES_VIEW)
        if frames_view is None:
            return []

        progress(0.5, desc="Performing search...")
        sim = frames_view.frame.similarity(query)

        # UI sliders may deliver a float; limit() needs an integer row count.
        results = (
            frames_view.order_by(sim, asc=False)
            .limit(int(num_results))
            .select(frames_view.frame, similarity=sim)
            .collect()
        )

        progress(1.0, desc="Search complete")
        return [row['frame'] for row in results]

    except Exception as e:
        # Best-effort search: log with traceback and show an empty gallery.
        logger.exception("Error during search: %s", e)
        return []
|
|
|
|
|
# Custom stylesheet passed to gr.Blocks: page container, logo header,
# numbered step headers and the examples section.
css = """
.container {
    max-width: 1200px;
    margin: 0 auto;
}
.header {
    display: flex;
    align-items: center;
    margin-bottom: 20px;
}
.header img {
    max-width: 120px;
    margin-right: 20px;
}
.step-header {
    background-color: #f5f5f5;
    padding: 10px;
    border-radius: 5px;
    margin-bottom: 15px;
}
.examples-section {
    margin-top: 30px;
}
"""
|
|
|
|
|
# Build the Gradio UI: upload/processing controls and search inputs on the
# left, the results gallery on the right, and example videos below.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.HTML(
        """
        <div class="header">
            <img src="https://raw.githubusercontent.com/pixeltable/pixeltable/main/docs/resources/pixeltable-logo-large.png" alt="Pixeltable" />
            <div>
                <h1>Video Frame Search with AI</h1>
                <p>Search through video content using natural language or images powered by <a href="https://github.com/pixeltable/pixeltable" target="_blank" style="color: #F25022; text-decoration: none; font-weight: bold;">Pixeltable</a>.</p>
            </div>
        </div>
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML('<div class="step-header"><h3>1. Insert video</h3></div>')

            video_file = gr.File(label="Upload Video", file_types=["video"])
            process_button = gr.Button("Process Video", variant="primary")
            process_output = gr.Textbox(label="Status", lines=2)

            gr.HTML('<div class="step-header"><h3>2. Search video frames</h3></div>')

            search_type = gr.Radio(
                ["Text", "Image"],
                label="Search Type",
                value="Text",
                info="Choose whether to search using text or an image",
            )
            text_input = gr.Textbox(
                label="Text Query",
                placeholder="Describe what you're looking for...",
                info="Example: 'person walking' or 'red car'",
            )
            # Hidden until the user switches the search type to "Image".
            image_input = gr.Image(
                label="Image Query",
                type="pil",
                visible=False,
                info="Upload an image to find similar frames",
            )
            num_results = gr.Slider(
                minimum=1,
                maximum=20,
                value=5,
                step=1,
                label="Number of Results",
                info="How many matching frames to display",
            )
            search_button = gr.Button("Search", variant="primary")

        with gr.Column(scale=2):
            gr.HTML('<div class="step-header"><h3>3. Visualize results</h3></div>')

            results_gallery = gr.Gallery(
                label="Search Results",
                columns=3,
                allow_preview=True,
                object_fit="contain",
            )

    # NOTE(review): the original nesting of this accordion could not be
    # recovered from the source text; it is placed below the main row.
    with gr.Accordion("Example Videos", open=False):
        gr.Markdown("Click one of the examples below to get started")
        gr.Examples(
            examples=[
                ["bangkok.mp4"],
                ["lotr.mp4"],
                ["mi.mp4"],
            ],
            inputs=[video_file],
            outputs=[process_output],
            fn=process_video,
            # process_video rebuilds the frame index as a side effect. Cached
            # examples would only replay a stored status message without
            # re-indexing, leaving the index out of sync with the clicked
            # video, so run the function on every click instead.
            cache_examples=False,
            run_on_click=True,
        )

    def update_search_input(choice):
        """Show the text box for "Text" searches and the image input for "Image"."""
        return gr.update(visible=choice == "Text"), gr.update(visible=choice == "Image")

    search_type.change(update_search_input, search_type, [text_input, image_input])

    process_button.click(
        process_video,
        inputs=[video_file],
        outputs=[process_output],
    )

    def perform_search(search_type, text_query, image_query, num_results):
        """Pick the active query and return a gallery update with the results.

        An empty or missing query clears the gallery and relabels it with a
        prompt; a successful search restores the normal label.
        """
        query = text_query if search_type == "Text" else image_query
        if query is None or (isinstance(query, str) and not query.strip()):
            # Update the existing gallery in place rather than returning a new
            # component instance.
            return gr.update(value=[], label="Please enter a valid search query")

        return gr.update(
            value=similarity_search(query, search_type, num_results),
            label="Search Results",
        )

    search_button.click(
        perform_search,
        inputs=[search_type, text_input, image_input, num_results],
        outputs=[results_gallery],
    )

    # Client-side only: drop focus from the radio button after a switch.
    search_type.change(lambda: None, None, None, _js="() => {document.activeElement.blur();}")

    # Pressing Enter in the text box triggers the same search as the button.
    text_input.submit(
        perform_search,
        inputs=[search_type, text_input, image_input, num_results],
        outputs=[results_gallery],
    )
|
|
|
# Start the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()